
Reformer

A PyTorch implementation of the Reformer network (https://openreview.net/pdf?id=rkgNKkHtvB)

Much of this code base is loosely translated from Google's JAX implementation: https://github.com/google/trax/blob/master/trax/models/research/reformer.py

How to use

All of the hard work has been taken care of; all you need to do is instantiate the model!

from reformer_lm.reformer_lm import ReformerLM
import torch

# Dummy input of shape (batch, seq_len, d_model) = (4, 4, 64)
test = torch.rand((4, 4, 64))
model = ReformerLM(
    vocab_size=300000,
    d_in=test.shape[-2],   # input/sequence dimension (4)
    d_out=test.shape[-1],  # output/embedding dimension (64)
    n_layers=6,
    n_heads=1,
    attn_k=test.shape[-1],  # attention key dimension
    attn_v=test.shape[-1],  # attention value dimension
)

output = model(test)
print(output)

This model is still in testing and will therefore continue to see updates. PRs are welcome! Feel free to take advantage of the Docker container for development. I have been working in notebooks to test code against the original paper, and then refactoring that code back into the package.


reformer_lm's People

Contributors

zbloss

reformer_lm's Issues

How to calculate the loss

Could you tell me what the output means when I call the model like this?

model = ReformerLM(
            vocab_size=model_config['vocab_size'],
            d_in=model_config['n_seq_len'],
            d_out=model_config['n_embd'],
            n_layers=model_config['n_layers'],
            n_heads=model_config['n_heads'],
            attn_k=model_config['attn_k'],
            attn_v=model_config['attn_v'],
        ) 
inputs = torch.rand((4, 150, 100))  # (batch_size, seq_len, dim)
output = model(inputs)

thanks
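One common way to answer this kind of question, assuming the model's output is per-token logits of shape (batch, seq_len, vocab_size), is a standard next-token cross-entropy loss. The shapes and layout below are illustrative assumptions, not confirmed against this repository's ReformerLM:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: the repo does not document its output layout, so we
# assume per-token logits of shape (batch, seq_len, vocab_size).
batch, seq_len, vocab_size = 4, 150, 300
logits = torch.randn(batch, seq_len, vocab_size)
targets = torch.randint(0, vocab_size, (batch, seq_len))

# Standard language-modeling loss: predict token t+1 from positions <= t,
# so drop the last logit and the first target before flattening.
shift_logits = logits[:, :-1, :].reshape(-1, vocab_size)
shift_targets = targets[:, 1:].reshape(-1)
loss = F.cross_entropy(shift_logits, shift_targets)
print(loss.item())
```

If the model instead returns log-probabilities (e.g. after a log-softmax), `F.nll_loss` on the flattened tensors would be the matching choice.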

Random projection?

Thanks for your sharing model!
However, it seems that you split the inputs directly into (#chunks) chunks, whereas the original paper uses a random projection and then takes the argmax as the hash value to determine each chunk?

Additionally, have you added the LSH part to the project? I am confused about how to implement it in code, so I would greatly appreciate any guidance!

Thanks in advance!
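For reference, the angular LSH scheme the question refers to (from the Reformer paper) hashes a vector x to argmax([xR; -xR]) for a random projection R. Below is a minimal sketch of that hashing step only; the function name and shapes are illustrative and not taken from this repository:

```python
import torch

def lsh_hash(vecs: torch.Tensor, n_buckets: int, seed: int = 0) -> torch.Tensor:
    """Map vectors of shape (..., d) to integer bucket ids in [0, n_buckets).

    Implements hash(x) = argmax([xR; -xR]) with a shared random projection R,
    as described in the Reformer paper (illustrative sketch).
    """
    assert n_buckets % 2 == 0, "the scheme uses n_buckets // 2 random rotations"
    d = vecs.shape[-1]
    gen = torch.Generator().manual_seed(seed)
    r = torch.randn(d, n_buckets // 2, generator=gen)  # random projection R
    rotated = vecs @ r                                 # (..., n_buckets // 2)
    # Concatenate [xR; -xR] and take the argmax as the bucket id.
    return torch.argmax(torch.cat([rotated, -rotated], dim=-1), dim=-1)

buckets = lsh_hash(torch.randn(4, 150, 64), n_buckets=8)
print(buckets.shape)  # one bucket id per sequence position
```

In the full LSH attention mechanism, positions are then sorted by bucket id and attention is restricted to chunks of same-bucket positions; that bookkeeping is not shown here.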
