kyegomez / shallowff

Zeta implementation of "Rethinking Attention: Exploring Shallow Feed-Forward Neural Networks as an Alternative to Attention Layers in Transformers"

Home Page: https://discord.gg/Yx5y5VBahs

License: MIT License

Makefile 5.43% Python 94.57%
artificial-intelligence attention attention-is-all-you-need attention-mechanism attention-mechanisms feedforward transformer transformer-encoder transformer-models transformers-models

shallowff's Introduction

ALR Transformer

An ALR Transformer that replaces the joint encoder + decoder block of the original transformer with a feedforward/ALR block paired with a decoder block.
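The cited paper's idea is to swap a self-attention layer for a shallow feed-forward network. One way to picture this is an MLP with a single hidden layer that sees the entire zero-padded sequence at once, which is why it needs a fixed maximum sequence length. The sketch below is a minimal illustration under that assumption; the class name `ShallowFFAttentionReplacement`, the hidden size and the ReLU activation are hypothetical choices, and this is not the `alr_transformer` API.

```python
import torch
from torch import nn
import torch.nn.functional as F


class ShallowFFAttentionReplacement(nn.Module):
    """Hypothetical sketch: a one-hidden-layer MLP over the whole flattened
    sequence, used where a self-attention layer would normally sit."""

    def __init__(self, dim: int, max_seq_len: int, hidden_dim: int):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.net = nn.Sequential(
            nn.Linear(max_seq_len * dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, max_seq_len * dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim) with seq_len <= max_seq_len
        batch, seq_len, dim = x.shape
        assert seq_len <= self.max_seq_len and dim == self.dim
        # zero-pad the sequence axis up to the fixed length the MLP expects
        x = F.pad(x, (0, 0, 0, self.max_seq_len - seq_len))
        out = self.net(x.reshape(batch, -1))  # (batch, max_seq_len * dim)
        out = out.reshape(batch, self.max_seq_len, dim)
        return out[:, :seq_len]  # drop the padded positions


torch.manual_seed(0)
layer = ShallowFFAttentionReplacement(dim=64, max_seq_len=32, hidden_dim=256)
x = torch.randn(2, 20, 64)
y = x + layer(x)  # residual connection, as around an attention layer
print(y.shape)  # torch.Size([2, 20, 64])
```

Because every output position is computed from every input position, this sketch has no causal masking and behaves like bidirectional (encoder-style) attention; its parameter count also grows with `max_seq_len * dim * hidden_dim`, which is the main cost of trading attention for a fixed-size MLP.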

Install

pip install alr-transformer

Usage

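import torch
from alr_transformer import ALRTransformer

# random token ids: a batch of 1 sequence of 2048 tokens, each id in [0, 100000)
x = torch.randint(0, 100000, (1, 2048))

model = ALRTransformer(
    dim = 512,
    depth = 6,
    num_tokens = 100000,  # vocabulary size; must be larger than every id in x
    dim_head = 64,
    heads = 8,
    ff_mult = 4
)

# forward pass over the token ids
out = model(x)
print(out)
print(out.shape)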

Train

  • First, git clone the repo, download the required files, and then run the following:
python3 train.py
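The issue report on this page shows `train.py` reading its data with `np.fromstring(file.read(int(95e6)), dtype=np.uint8)`, which makes NumPy emit a DeprecationWarning recommending `frombuffer`. A minimal sketch of the same read with `np.frombuffer`, assuming a plain binary file; the helper name `load_bytes` and the temporary file are hypothetical and stand in for whatever file `train.py` actually opens.

```python
import os
import tempfile

import numpy as np


def load_bytes(path: str, num_bytes: int) -> np.ndarray:
    """Read up to num_bytes from path as a uint8 array without np.fromstring."""
    with open(path, "rb") as f:
        # np.frombuffer returns a read-only view over the bytes object;
        # .copy() gives a writable array that owns its memory.
        return np.frombuffer(f.read(num_bytes), dtype=np.uint8).copy()


# Usage with a small temporary file (train.py reads int(95e6) bytes instead).
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"hello world")
data = load_bytes(tmp.name, num_bytes=5)
os.remove(tmp.name)
print(data)  # [104 101 108 108 111]
```

The `.copy()` matters if the array is later handed to code that writes into it or wraps it with `torch.from_numpy`, since a `frombuffer` view over a `bytes` object is not writable.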

Citation

@misc{bozic2023rethinking,
    title={Rethinking Attention: Exploring Shallow Feed-Forward Neural Networks as an Alternative to Attention Layers in Transformers}, 
    author={Vukasin Bozic and Danilo Dordevic and Daniele Coppola and Joseph Thommes},
    year={2023},
    eprint={2311.10642},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}

shallowff's People

Contributors

kyegomez

shallowff's Issues

[BUG] torch.cuda.OutOfMemoryError: CUDA out of memory

v@v-System-Product-Name:~/ShallowFF$ /bin/python3 /home/v/ShallowFF/train.py
/home/v/ShallowFF/train.py:52: DeprecationWarning: The binary mode of fromstring is deprecated, as it behaves surprisingly on unicode inputs. Use frombuffer instead
  X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
training: 0%| | 0/100000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/v/ShallowFF/train.py", line 87, in <module>
    loss = model(next(train_loader))
  File "/home/v/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/v/ShallowFF/alr_transformer/at.py", line 82, in forward
    logits = self.net(x_inp, **kwargs)
  File "/home/v/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/v/ShallowFF/alr_transformer/model.py", line 203, in forward
    x = self.transformer(x)
  File "/home/v/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/v/ShallowFF/alr_transformer/model.py", line 178, in forward
    x = block(x) + x
  File "/home/v/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/v/ShallowFF/alr_transformer/model.py", line 138, in forward
    sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.69 GiB total capacity; 1.40 GiB already allocated; 492.12 MiB free; 1.43 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
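The failing line builds the masked attention scores (`sim.masked_fill(...)` in `model.py`), a tensor whose size grows with the square of the sequence length. The sketch below sizes it, assuming a `(batch, heads, seq_len, seq_len)` score tensor in float32; the batch, head and length values are hypothetical and were not read from `train.py`. It also does the arithmetic on the numbers reported in the error message.

```python
def attention_scores_mib(batch: int, heads: int, seq_len: int,
                         bytes_per_elem: int = 4) -> float:
    """Size in MiB of one (batch, heads, seq_len, seq_len) attention-score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem / 2**20


# Hypothetical settings (not taken from train.py): batch 8, 8 heads, 2048 tokens, fp32.
print(attention_scores_mib(batch=8, heads=8, seq_len=2048))  # 1024.0

# masked_fill is out-of-place, so the original scores and the masked copy
# coexist briefly; halving seq_len cuts each such tensor by a factor of 4.
print(attention_scores_mib(batch=8, heads=8, seq_len=1024))  # 256.0

# Figures copied from the error message above.
total_gib = 15.69     # total capacity of GPU 0
reserved_gib = 1.43   # reserved by PyTorch's caching allocator
free_mib = 492.12     # free on the device
outside_gib = total_gib - reserved_gib - free_mib / 1024
print(round(outside_gib, 2))  # 13.78
```

On those reported numbers, roughly 13.8 GiB of the card was held outside PyTorch's allocator in this process (for example by another process or the CUDA context), so checking `nvidia-smi` for other GPU users, and lowering the batch size or sequence length, are plausible first steps.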