

gMLP - Pytorch

Implementation of gMLP, an all-MLP replacement for Transformers, in PyTorch.

Install

$ pip install g-mlp-pytorch

Usage

For masked language modelling

import torch
from torch import nn
from g_mlp_pytorch import gMLP

model = gMLP(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 256,
    circulant_matrix = True,      # use a circulant weight matrix so parameters grow linearly with sequence length
    act = nn.Tanh()               # activation for spatial gate (defaults to identity)
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x) # (1, 256, 20000)
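The circulant_matrix comment above can be illustrated with a minimal pure-Python sketch: a circulant matrix is fully determined by its first row, so an n x n spatial weight needs only n free parameters instead of n². The indexing convention W[i][j] = params[(j - i) % n] and the helper names circulant and spatial_param_counts are assumptions for illustration; the library's own construction may differ in its details.

```python
from typing import Dict, List


def circulant(params: List[float]) -> List[List[float]]:
    """Expand n parameters into an n x n circulant matrix.

    Convention assumed here: W[i][j] = params[(j - i) % n], so each row is
    the previous row shifted one position to the right (wrapping around).
    """
    n = len(params)
    return [[params[(j - i) % n] for j in range(n)] for i in range(n)]


def spatial_param_counts(seq_len: int) -> Dict[str, int]:
    # free parameters of a full spatial weight vs. a circulant one (biases excluded)
    return {"full": seq_len * seq_len, "circulant": seq_len}


print(circulant([1.0, 2.0, 3.0]))
# [[1.0, 2.0, 3.0], [3.0, 1.0, 2.0], [2.0, 3.0, 1.0]]
print(spatial_param_counts(256))
# {'full': 65536, 'circulant': 256}
```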

For image classification

import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6
)

img = torch.randn(1, 3, 256, 256)
logits = model(img) # (1, 1000)
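To make patch_size concrete, here is a minimal pure-Python sketch of splitting a channels x height x width image into non-overlapping patches, each flattened into one token. For the 3 x 256 x 256 input above with patch_size = 16, this yields (256 / 16)² = 256 tokens of 16 · 16 · 3 = 768 values each, which the model then presumably projects to dim. The helper patchify and the (row, column, channel) flattening order are assumptions, not the library's code.

```python
from typing import List

Image = List[List[List[float]]]  # indexed as [channel][row][column]


def patchify(img: Image, ph: int, pw: int) -> List[List[float]]:
    """Split an image into ph x pw patches, each flattened to length C * ph * pw."""
    c, h, w = len(img), len(img[0]), len(img[0][0])
    assert h % ph == 0 and w % pw == 0, "image dimensions must be divisible by the patch size"
    patches = []
    for top in range(0, h, ph):            # patch rows, top to bottom
        for left in range(0, w, pw):       # patch columns, left to right
            patch = []
            for i in range(ph):
                for j in range(pw):
                    for ch in range(c):    # channels vary fastest within a patch
                        patch.append(img[ch][top + i][left + j])
            patches.append(patch)
    return patches


# 1-channel 4x4 image whose pixel value is row * 4 + col
img = [[[float(r * 4 + c) for c in range(4)] for r in range(4)]]
patches = patchify(img, 2, 2)
print(len(patches), patches[0])  # 4 [0.0, 1.0, 4.0, 5.0]
```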

You can also add a tiny amount of single-headed attention to boost performance, as described in the paper (aMLP), by passing one extra keyword argument, attn_dim. This works for both gMLPVision and gMLP.

import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    attn_dim = 64
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)
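The "tiny attention" of aMLP is ordinary single-head scaled dot-product attention, softmax(q kᵀ / sqrt(d)) v, with d = attn_dim. The minimal pure-Python sketch below covers only that step; it assumes q, k and v have already been produced by projecting the tokens to attn_dim, and how gMLPVision combines the result with the gating unit is not shown. The function names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]  # rows are token positions


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)                          # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def single_head_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """One head of scaled dot-product attention: softmax(q k^T / sqrt(d)) v."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for qi in q:
        # similarity of this query to every key, then normalized attention weights
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        weights = softmax(scores)
        # weighted sum of the value vectors, one output channel at a time
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


# identical keys give uniform weights, so the output is the mean of the values
print(single_head_attention([[1.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]], [[2.0, 0.0], [0.0, 4.0]]))
# [[1.0, 2.0]]
```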

Non-square images and patch sizes

import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = (256, 128),
    patch_size = (16, 8),
    num_classes = 1000,
    dim = 512,
    depth = 6,
    attn_dim = 64
)

img = torch.randn(1, 3, 256, 128)
pred = model(img) # (1, 1000)
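Non-square inputs work as long as each side of the image is divisible by the matching side of the patch. A small sketch of that arithmetic follows; the pair helper, which treats an int as a square size, is a hypothetical illustration rather than the library's code. Note that (256, 128) with (16, 8) patches and 256 with 16 patches both give 256 tokens, so the spatial weight has the same 256 x 256 shape in either case.

```python
from typing import Tuple, Union

Size = Union[int, Tuple[int, int]]


def pair(value: Size) -> Tuple[int, int]:
    # hypothetical helper: an int means a square (height, width)
    return value if isinstance(value, tuple) else (value, value)


def num_patches(image_size: Size, patch_size: Size) -> int:
    """Number of tokens the image is split into; each side must divide evenly."""
    ih, iw = pair(image_size)
    ph, pw = pair(patch_size)
    if ih % ph != 0 or iw % pw != 0:
        raise ValueError(f"image size {(ih, iw)} is not divisible by patch size {(ph, pw)}")
    return (ih // ph) * (iw // pw)


print(num_patches((256, 128), (16, 8)))  # 256
print(num_patches(256, 16))              # 256
```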

Experimental

An independent researcher proposes a multi-headed approach for gMLPs in a blog post on Zhihu. To use it, set heads to a value greater than 1.

import torch
from g_mlp_pytorch import gMLP

model = gMLP(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 256,
    causal = True,
    circulant_matrix = True,
    heads = 4 # 4 heads
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x) # (1, 256, 20000)
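One way to read heads > 1 for the spatial gating unit: split the gate's channels evenly across heads and give each head its own seq_len x seq_len spatial weight and bias, so different channel groups can mix tokens differently. The pure-Python sketch below assumes that design (it is not taken from the library or the blog post) and also shows the causal = True masking, where position i only mixes in positions j <= i. All names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def multihead_spatial_mix(gate: Matrix, weights: List[Matrix], biases: List[List[float]],
                          causal: bool = False) -> Matrix:
    """gate: seq_len x channels; weights[h]: seq_len x seq_len; biases[h]: length seq_len."""
    n, d = len(gate), len(gate[0])
    heads = len(weights)
    assert d % heads == 0, "channels must split evenly across heads"
    dh = d // heads
    out = [[0.0] * d for _ in range(n)]
    for h in range(heads):
        w, b = weights[h], biases[h]
        for c in range(h * dh, (h + 1) * dh):   # channels owned by head h
            for i in range(n):                   # output position
                acc = b[i]
                for j in range(n):               # input position
                    if causal and j > i:
                        continue                 # no mixing in from future positions
                    acc += w[i][j] * gate[j][c]
                out[i][c] = acc
    return out


gate = [[1.0, 10.0], [2.0, 20.0]]
weights = [[[1.0, 1.0], [1.0, 1.0]],   # head 0: sums over positions
           [[1.0, 0.0], [0.0, 1.0]]]   # head 1: identity
biases = [[0.0, 0.0], [1.0, 1.0]]
print(multihead_spatial_mix(gate, weights, biases))                # [[3.0, 11.0], [3.0, 21.0]]
print(multihead_spatial_mix(gate, weights, biases, causal=True))   # [[1.0, 11.0], [3.0, 21.0]]
```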

Citations

@misc{liu2021pay,
    title   = {Pay Attention to MLPs}, 
    author  = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year    = {2021},
    eprint  = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@software{peng_bo_2021_5196578,
    author       = {PENG Bo},
    title        = {BlinkDL/RWKV-LM: 0.01},
    month        = aug,
    year         = 2021,
    publisher    = {Zenodo},
    version      = {0.01},
    doi          = {10.5281/zenodo.5196578},
    url          = {https://doi.org/10.5281/zenodo.5196578}
}


Contributors

lucidrains


g-mlp-pytorch's Issues

pre-trained model

Thanks for your amazing work. Has anyone trained the LM already? In other words, is a pre-trained gMLP available?

Parameter count doesn't line up with paper

Just a note (and correct me if I misunderstood the paper) -

The parameter count for Tiny gMLP doesn't line up with the paper's count for 30 layers, dim 128 and ff_mult 6.
That's probably due to the doubling of parameters here: https://github.com/lucidrains/g-mlp-pytorch/blob/main/g_mlp_pytorch/g_mlp_pytorch.py#L111

Halving this back to dim_ff also requires halving the respective dims on all 3 lines here: https://github.com/lucidrains/g-mlp-pytorch/blob/main/g_mlp_pytorch/g_mlp_pytorch.py#L64-L66

With both changes, the parameter count is roughly 5.5M.
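The arithmetic in this issue can be checked with a rough per-block count. This sketch assumes the paper's block layout (pre-norm, input projection to d_ffn, split into two halves, LayerNorm plus a seq_len x seq_len spatial projection on the gate half, output projection from d_ffn / 2 back to d_model), d_ffn = 128 x 6 = 768, and seq_len = 196 (a 224 x 224 image in 16 x 16 patches). It ignores the patch embedding and classifier head. The function name and the doubled flag, which models the doubled input projection described above, are hypothetical.

```python
def gmlp_block_params(d_model: int, d_ffn: int, seq_len: int, doubled: bool = False) -> int:
    """Parameters in one gMLP block (weights + biases), under the layout assumed above."""
    proj_width = 2 * d_ffn if doubled else d_ffn   # output width of the input projection
    half = proj_width // 2                          # channels in each half of the gating unit
    norm = 2 * d_model                              # pre-norm LayerNorm: weight + bias
    proj_in = d_model * proj_width + proj_width     # Linear(d_model, proj_width)
    sgu_norm = 2 * half                             # LayerNorm on the gate half
    spatial = seq_len * seq_len + seq_len           # spatial projection + bias
    proj_out = half * d_model + d_model             # Linear(half, d_model)
    return norm + proj_in + sgu_norm + spatial + proj_out


for doubled in (False, True):
    print(doubled, 30 * gmlp_block_params(128, 768, 196, doubled))
# False 5639640
# True 10109400
```

Under these assumptions, the doubled projection adds roughly 4.5M parameters across the 30 blocks.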

Custom image sizes?

Hi,
Thanks for your great (and very fast) contribution!
I was wondering if you could help me figure out how to apply this to a different image size?
It's not really an image, but rather a 2D tensor of size 4096 x 100.

I saw that I can change the number of channels, so I could just set channels to 1.
But I see that, first, your implementation is for square images and, second, it requires the image size to be divisible by the patch size.

Since you've written this implementation, perhaps you could help me adapt it to my needs (and maybe help other users with theirs)?

Maybe I could pad the 100 dimension to 128 so that both dimensions would be divisible by 16, for example? But then where do I set different h and w?

Thanks.
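For the 4096 x 100 case, the smallest padding that makes the second dimension divisible by 16 is 112, not 128. A minimal sketch of the arithmetic follows; pad_to_multiple is a hypothetical helper. With the tuple form of image_size shown in the README's non-square example, the padded tensor could then presumably be passed with image_size = (4096, 112), patch_size = (16, 16) and channels = 1; how to fill the padding (for example with zeros) is an assumption.

```python
def pad_to_multiple(size: int, multiple: int) -> int:
    """Smallest value >= size that is divisible by multiple."""
    return -(-size // multiple) * multiple   # ceiling division, then scale back up


h, w, patch = 4096, 100, 16
padded_w = pad_to_multiple(w, patch)
print(padded_w)                              # 112
print((h // patch) * (padded_w // patch))    # 1792 tokens
```

Note that 1792 tokens means a 1792 x 1792 spatial weight in every layer, roughly 3.2M parameters per layer without circulant weights.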

Potentially missing the highway path

Hello,

Maybe I missed it, but would you mind pointing out where the highway path of the gMLP block is in the code? Based on the paper, there is a highway path (a residual addition) between the block's input and output. I couldn't find it in the gMLPBlock code.

Thank you
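The "highway" path in the paper is a residual connection: each block returns its input plus the block's output, which requires matching shapes. A minimal pure-Python sketch of that wrapper pattern follows; the names residual and toy_block are hypothetical. In a PyTorch codebase the same idea is often written as a small module that wraps each block rather than code inside the block itself, which may be why it is not visible in gMLPBlock.

```python
from typing import Callable, List

Vector = List[float]


def residual(fn: Callable[[Vector], Vector]) -> Callable[[Vector], Vector]:
    """Wrap fn so the result is fn(x) + x, i.e. the shortcut (highway) path."""
    def wrapped(x: Vector) -> Vector:
        out = fn(x)
        assert len(out) == len(x), "the shortcut requires matching shapes"
        return [o + xi for o, xi in zip(out, x)]
    return wrapped


def toy_block(x: Vector) -> Vector:
    # stand-in for a real gMLP block: just halves every value
    return [0.5 * v for v in x]


block = residual(toy_block)
print(block([1.0, 2.0]))  # [1.5, 3.0]
```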

Don't you think this is more legible?

```python
import torch
from torch import nn

def exists(val):
    return val is not None

class SpatialGatingUnit(nn.Module):
    def __init__(self, dim, dim_seq, causal = False, act = nn.Identity(), init_eps = 1e-3):
        super().__init__()
        dim_out = dim // 2
        self.causal = causal
        self.dim_seq = dim_seq

        self.norm = nn.LayerNorm(dim_out)

        # explicit spatial weight W (dim_seq x dim_seq) and bias b, replacing nn.Conv1d(dim_seq, dim_seq, 1)
        # W starts near zero and b at one, so the gate output starts close to 1
        init_eps /= dim_seq
        self.w_ = nn.Parameter(torch.empty(dim_seq, dim_seq).uniform_(-init_eps, init_eps))
        self.b_ = nn.Parameter(torch.ones(dim_seq))

        self.act = act

    def forward(self, x, gate_res = None):  # x: (batch, seq_len, dim)
        device, n = x.device, x.shape[1]

        res, gate = x.chunk(2, dim = -1)      # each half: (batch, seq_len, dim // 2)
        gate = self.norm(gate)

        weight, bias = self.w_, self.b_       # weight: (dim_seq, dim_seq), bias: (dim_seq,)

        if self.causal:
            # crop to the current length and zero out weights from future positions
            weight, bias = weight[:n, :n], bias[:n]
            mask = torch.ones(n, n, device = device).triu_(1).bool()
            weight = weight.masked_fill(mask, 0.)

        # WZ + b: mix along the sequence axis; without causal, seq_len must equal dim_seq
        gate = torch.matmul(weight, gate) + bias[None, :, None]

        if exists(gate_res):
            gate = gate + gate_res

        return self.act(gate) * res
```
