linear-attention-transformer's Introduction

Linear Attention Transformer

A fully featured Transformer that mixes (QKᵀ)V local attention with Q(KᵀV) global attention (scales linearly with respect to sequence length) for efficient long-range language modeling.
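
The linear-scaling term comes from grouping the product as Q(KᵀV) instead of (QKᵀ)V. A toy sketch of this associativity trick (illustrative only, single head, un-normalized attention):

import torch

# Illustration (not the library's code): for un-normalized attention both
# groupings give the same result, but Q(KᵀV) never materializes the n x n matrix.
n, d = 1024, 64
q, k, v = torch.randn(3, n, d).unbind(0)

quadratic = (q @ k.transpose(-2, -1)) @ v   # O(n^2 * d), builds an n x n intermediate
linear    = q @ (k.transpose(-2, -1) @ v)   # O(n * d^2), builds only a d x d intermediate

assert torch.allclose(quadratic, linear, atol = 1e-3, rtol = 1e-3)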

Install

$ pip install linear-attention-transformer

Usage

Language model

import torch
from linear_attention_transformer import LinearAttentionTransformerLM

model = LinearAttentionTransformerLM(
    num_tokens = 20000,
    dim = 512,
    heads = 8,
    depth = 1,
    max_seq_len = 8192,
    causal = True,                  # auto-regressive or not
    ff_dropout = 0.1,               # dropout for feedforward
    attn_layer_dropout = 0.1,       # dropout right after self-attention layer
    attn_dropout = 0.1,             # dropout post-attention
    emb_dim = 128,                  # embedding factorization, to save on memory
    dim_head = 128,                 # fix the dimension of each head, independent of the embedding dimension and the number of heads
    blindspot_size = 64,            # gives the q(kv) attention a blindspot of 64 tokens back in the causal case, in exchange for an order of magnitude of memory savings. should be paired with local attention whose window size is at least this value. setting this to 1 allows full q(kv) attention over the past
    n_local_attn_heads = 4,         # number of local attention heads for (qk)v attention. this can be a tuple specifying the exact number of local attention heads at each depth
    local_attn_window_size = 128,   # receptive field of the local attention
    reversible = True,              # use reversible nets, from Reformer paper
    ff_chunks = 2,                  # feedforward chunking, from Reformer paper
    ff_glu = True,                  # use GLU variant for feedforward
    attend_axially = False,         # folds the sequence by the local attention window size and does an extra strided attention, followed by a feedforward, using the cheap q(kv) attention
    shift_tokens = True             # add single token shifting, for greatly improved convergence
).cuda()

x = torch.randint(0, 20000, (1, 8192)).cuda()
model(x) # (1, 8192, 20000)

Transformer

import torch
from linear_attention_transformer import LinearAttentionTransformer

model = LinearAttentionTransformer(
    dim = 512,
    heads = 8,
    depth = 1,
    max_seq_len = 8192,
    n_local_attn_heads = 4
).cuda()

x = torch.randn(1, 8192, 512).cuda()
model(x) # (1, 8192, 512)

Encoder / decoder

import torch
from linear_attention_transformer import LinearAttentionTransformerLM

enc = LinearAttentionTransformerLM(
    num_tokens = 20000,
    dim = 512,
    heads = 8,
    depth = 6,
    max_seq_len = 4096,
    reversible = True,
    n_local_attn_heads = 4,
    return_embeddings = True
).cuda()

dec = LinearAttentionTransformerLM(
    num_tokens = 20000,
    dim = 512,
    heads = 8,
    depth = 6,
    causal = True,
    max_seq_len = 4096,
    reversible = True,
    receives_context = True,
    n_local_attn_heads = 4
).cuda()

src = torch.randint(0, 20000, (1, 4096)).cuda()
src_mask = torch.ones_like(src).bool().cuda()

tgt = torch.randint(0, 20000, (1, 4096)).cuda()
tgt_mask = torch.ones_like(tgt).bool().cuda()

context = enc(src, input_mask = src_mask)
logits = dec(tgt, context = context, input_mask = tgt_mask, context_mask = src_mask)

Linformer

Linformer is another variant of attention with linear complexity, championed by Facebook AI. It only works with non-autoregressive models of a fixed sequence length. If your problem satisfies those criteria, you may choose to try it out.

from linear_attention_transformer import LinearAttentionTransformerLM, LinformerSettings

settings = LinformerSettings(k = 256)

enc = LinearAttentionTransformerLM(
    num_tokens = 20000,
    dim = 512,
    heads = 8,
    depth = 6,
    max_seq_len = 4096,
    linformer_settings = settings
).cuda()

You can also use Linformer for the contextual attention layer, if the contextual keys are of a fixed sequence length.

from linear_attention_transformer import LinearAttentionTransformerLM, LinformerContextSettings

settings = LinformerContextSettings(
  seq_len = 2048,
  k = 256
)

dec = LinearAttentionTransformerLM(
    num_tokens = 20000,
    dim = 512,
    heads = 8,
    depth = 6,
    max_seq_len = 4096,
    causal = True,
    context_linformer_settings = settings,
    receives_context = True
).cuda()

Images

This repository also contains a concise implementation of this efficient attention for images.

import torch
from linear_attention_transformer.images import ImageLinearAttention

attn = ImageLinearAttention(
  chan = 32,
  heads = 8,
  key_dim = 64       # can be decreased to 32 for more memory savings
)

img = torch.randn(1, 32, 256, 256)
attn(img) # (1, 32, 256, 256)

Citations

@inproceedings{katharopoulos-et-al-2020,
  author    = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
  title     = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
  booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
  year      = {2020},
  url       = {https://arxiv.org/abs/2006.16236}
}
@article{shen2019efficient,
  author    = {Zhuoran Shen and
               Mingyuan Zhang and
               Haiyu Zhao and
               Shuai Yi and
               Hongsheng Li},
  title     = {Efficient Attention: Attention with Linear Complexities},
  journal   = {CoRR},
  volume    = {abs/1812.01243},
  year      = {2018},
  url       = {http://arxiv.org/abs/1812.01243}
}
@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{wang2020linformer,
    title   = {Linformer: Self-Attention with Linear Complexity},
    author  = {Sinong Wang and Belinda Z. Li and Madian Khabsa and Han Fang and Hao Ma},
    year    = {2020},
    eprint  = {2006.04768}
}
@misc{bhojanapalli2020lowrank,
    title   = {Low-Rank Bottleneck in Multi-head Attention Models},
    author  = {Srinadh Bhojanapalli and Chulhee Yun and Ankit Singh Rawat and Sashank J. Reddi and Sanjiv Kumar},
    year    = {2020},
    eprint  = {2002.07028}
}
@techreport{zhuiyiroformer,
    title   = {RoFormer: Transformer with Rotary Position Embeddings - ZhuiyiAI},
    author  = {Jianlin Su},
    year    = {2021},
    url     = "https://github.com/ZhuiyiTechnology/roformer",
}

linear-attention-transformer's Issues

Why dim != dim_head * heads?

Dear developer,
In your usage example of LinearAttentionTransformerLM, dim != dim_head * heads. I am a little confused by that. Is this an intentional feature of the algorithm?
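
For context, the README's dim_head comment says the head dimension can be fixed independently of the embedding dimension and the number of heads (see also the Low-Rank Bottleneck citation). A minimal sketch of the usual pattern behind such an option, not the repo's exact code:

import torch
import torch.nn as nn

# Hypothetical illustration of a decoupled head width (not the repo's exact code):
# q/k/v projections map dim -> heads * dim_head, and a final projection maps back.
dim, heads, dim_head = 512, 8, 128          # dim_head * heads != dim is allowed

to_q   = nn.Linear(dim, heads * dim_head, bias = False)   # dim -> heads * dim_head
to_out = nn.Linear(heads * dim_head, dim)                  # back to dim afterwards

x = torch.randn(1, 1024, dim)
q = to_q(x)                                  # (1, 1024, heads * dim_head)
out = to_out(q)                              # (1, 1024, dim)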

Questions on the implementation of a linear variant and reference

In the setup that is not Linformer:

if not exists(linformer_settings):
    attn = SelfAttention(dim, heads, causal, dim_head = dim_head, blindspot_size = blindspot_size, n_local_attn_heads = local_heads, local_attn_window_size = local_attn_window_size, dropout = attn_layer_dropout, attn_dropout = attn_dropout)

The SelfAttention function uses

def linear_attn(q, k, v, kv_mask = None):
    dim = q.shape[-1]

    if exists(kv_mask):
        mask_value = max_neg_value(q)
        mask = kv_mask[:, None, :, None]
        k = k.masked_fill_(~mask, mask_value)
        v = v.masked_fill_(~mask, 0.)
        del mask

    q = q.softmax(dim=-1)
    k = k.softmax(dim=-2)

    q = q * dim ** -0.5

    context = einsum('bhnd,bhne->bhde', k, v)
    attn = einsum('bhnd,bhde->bhne', q, context)
    return attn.reshape(*q.shape)

May I know exactly which paper this linear attention is implementing? @lucidrains
Also, I have a noob question: is torch.einsum more efficient than matmul? Couldn't we just use Softmax(Q)(Softmax(K^T) V)?
Greatly appreciated.

-Shuhao

EDIT: NVM, figured out, thanks.
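
For reference, the two einsum calls in linear_attn above are plain batched matrix products, so the whole expression is indeed Softmax(Q) @ (Softmax(K)ᵀ @ V). A small sketch with made-up shapes:

import torch

# Sketch only: reproduce the einsum pair from linear_attn with explicit matmuls.
b, h, n, d = 2, 8, 1024, 64
q = torch.randn(b, h, n, d).softmax(dim = -1) * d ** -0.5
k = torch.randn(b, h, n, d).softmax(dim = -2)
v = torch.randn(b, h, n, d)

context    = torch.einsum('bhnd,bhne->bhde', k, v)      # K^T V, a d x d matrix per head
out_einsum = torch.einsum('bhnd,bhde->bhne', q, context)

out_matmul = q @ (k.transpose(-2, -1) @ v)

assert torch.allclose(out_einsum, out_matmul, atol = 1e-5)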

Challenge in replacing SelfAttention with ImageLinearAttention in Vision Transformer

When I replace SelfAttention with ImageLinearAttention in a Vision Transformer, using the code below, I get a RuntimeError. The code for ImageLinearAttention is from https://github.com/lucidrains/linear-attention-transformer/blob/master/linear_attention_transformer/images.py, except that I removed the number of channels, as you can see in the commented code.

class ImageLinearAttention(nn.Module):
    def __init__(self, chan, chan_out = None, kernel_size = 1, padding = 0, stride = 1, key_dim = 64, value_dim = 64, heads = 8, norm_queries = True):
        super().__init__()
        self.chan = chan
        chan_out = chan if chan_out is None else chan_out

        self.key_dim = key_dim
        self.value_dim = value_dim
        self.heads = heads

        self.norm_queries = norm_queries

        conv_kwargs = {'padding': padding, 'stride': stride}
        self.to_q = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
        self.to_k = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
        self.to_v = nn.Conv2d(chan, value_dim * heads, kernel_size, **conv_kwargs)
        print('value dim: ', value_dim)
        print('chan out: ', chan_out)
        print('kernel_size: ', kernel_size)
        out_conv_kwargs = {'padding': padding}
        print('out_conv_kwargs: ', out_conv_kwargs)
        print('in_chan: ', value_dim * heads)
        self.to_out = nn.Conv2d(value_dim * heads, chan_out, kernel_size, **out_conv_kwargs)

    def forward(self, x, context = None):
        print('x.shape: ', x.shape)
        print('*x.shape is: ', *x.shape)
        print('heads: ', self.heads)
        #b, c, h, w, k_dim, heads = *x.shape, self.key_dim, self.heads
        b, h, w, k_dim, heads = *x.shape, self.key_dim, self.heads
        q, k, v = (self.to_q(x), self.to_k(x), self.to_v(x))
        q, k, v = map(lambda t: t.reshape(b, heads, -1, h * w), (q, k, v))
        q, k = map(lambda x: x * (self.key_dim ** -0.25), (q, k))
        
        if context is not None:
            #context = context.reshape(b, c, 1, -1)
            context = context.reshape(b, 1, -1)
            ck, cv = self.to_k(context), self.to_v(context)
            ck, cv = map(lambda t: t.reshape(b, heads, k_dim, -1), (ck, cv))
            k = torch.cat((k, ck), dim=3)
            v = torch.cat((v, cv), dim=3)

        k = k.softmax(dim=-1)

        if self.norm_queries:
            q = q.softmax(dim=-2)

        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhdn,bhde->bhen', q, context)
        out = out.reshape(b, -1, h, w)
        out = self.to_out(out)
        return out

Error is:
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [384, 512, 1, 1], but got 3-dimensional input of size [1, 1984, 512] instead

Also, the data fed to the transformer is of size torch.Size([1983, 512]), and my batch size is 1.
Full log is:

$ bash scripts/train.sh 
train: True test: False cam: False
preparing datasets and dataloaders......
total_train_num:  176
creating models......
n_class:  2
in_dim:  512
value dim:  64
chan out:  512
kernel_size:  1
out_conv_kwargs:  {'padding': 0}
in_chan:  768
in_dim:  512
value dim:  64
chan out:  512
kernel_size:  1
out_conv_kwargs:  {'padding': 0}
in_chan:  768

=>Epoches 1, learning rate = 0.0010000, previous best = 0.0000
torch.Size([1983, 512])
features size:  torch.Size([1983, 512])
/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:129: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:154: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.
  warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
max_feature_num:  1983
batch feature size:  torch.Size([1, 1983, 512])
x.shape:  torch.Size([1, 1984, 512])
*x.shape is:  1 1984 512
heads:  12
Traceback (most recent call last):
  File "main.py", line 148, in <module>
    preds,labels,loss = trainer.train(sample_batched, model)
  File "/SeaExp/mona/research/code/cc/helper.py", line 71, in train
    pred,labels,loss = model.forward(feats, labels, masks)
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 166, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/SeaExp/mona/research/code/cc/models/Transformer.py", line 31, in forward
    out = self.transformer(X)
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/SeaExp/mona/research/code/cc/models/linear_att_ViT.py", line 262, in forward
    feat = self.transformer(emb)
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/SeaExp/mona/research/code/cc/models/linear_att_ViT.py", line 206, in forward
    out = layer(out)
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/SeaExp/mona/research/code/cc/models/linear_att_ViT.py", line 174, in forward
    out = self.attn(out)
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/SeaExp/mona/research/code/cc/models/linear_att_ViT.py", line 92, in forward
    q, k, v = (self.to_q(x), self.to_k(x), self.to_v(x))
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 443, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 439, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [384, 512, 1, 1], but got 3-dimensional input of size [1, 1984, 512] instead

The original SelfAttention code is:

class SelfAttention(nn.Module):
    def __init__(self, in_dim, heads=8, dropout_rate=0.1):
        super(SelfAttention, self).__init__()
        self.heads = heads
        self.head_dim = in_dim // heads
        self.scale = self.head_dim ** 0.5
        
        self.query = LinearGeneral((in_dim,), (self.heads, self.head_dim))
        self.key = LinearGeneral((in_dim,), (self.heads, self.head_dim))
        self.value = LinearGeneral((in_dim,), (self.heads, self.head_dim))
        self.out = LinearGeneral((self.heads, self.head_dim), (in_dim,))

        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = None

    def forward(self, x):
        b, n, _ = x.shape

        q = self.query(x, dims=([2], [0]))
        k = self.key(x, dims=([2], [0]))
        v = self.value(x, dims=([2], [0]))

        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        out = torch.matmul(attn_weights, v)
        out = out.permute(0, 2, 1, 3)

        out = self.out(out, dims=([2, 3], [0, 1]))

        return out

How can I fix this error? I am calling ImageLinearAttention as follows in the EncoderBlock of the Vision Transformer:

class EncoderBlock(nn.Module):
    def __init__(self, in_dim, mlp_dim, num_heads, dropout_rate=0.1, attn_dropout_rate=0.1):
        super(EncoderBlock, self).__init__()

        self.norm1 = nn.LayerNorm(in_dim)
        #self.attn = SelfAttention(in_dim, heads=num_heads, dropout_rate=attn_dropout_rate)
        ## note Mona: not sure if I am correctly passing the params
        # what about attn_dropout_rate=0.1
        ## I don't know 
        print('in_dim: ', in_dim) 
        self.attn = ImageLinearAttention(chan=in_dim, heads=num_heads, key_dim=32)
        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = None
        self.norm2 = nn.LayerNorm(in_dim)
        self.mlp = MlpBlock(in_dim, mlp_dim, in_dim, dropout_rate)

    def forward(self, x):
        residual = x
        out = self.norm1(x)
        out = self.attn(out)
        if self.dropout:
            out = self.dropout(out)
        out += residual
        residual = out

        out = self.norm2(out)
        out = self.mlp(out)
        out += residual
        return out

Tooooo many functions added, but no annotations

Dear author @lucidrains ,

This is really impressive work. Functions from many papers are added into a single project. However, there are no clear annotations, which makes it difficult to understand which is which and why each of these functions was added. Would it be possible for you to add annotations for these functions? And how much improvement comes from each one?

Below are the functions included based on my own observation:

  1. Support for multiple linear transformers: Linformer, Reformer, Efficient Attention, Longformer
  2. Support for encoder, decoder, and full transformer
  3. Support for reversible networks
  4. Support for positional embeddings: rotary embedding, axial positional embedding, standard absolute positional embedding
  5. Support for causal and non-causal attention
  6. Support for global and local attention heads, as in ETC

Best regards!

ImageLinearAttention showcase

Could you please show how you would use ImageLinearAttention for image classification in combination with ViT? Do you replace SelfAttention with it, or use it alongside SelfAttention? Any representative example would be really appreciated. I want to use it in a ViT for images.

class EncoderBlock(nn.Module):
    def __init__(self, in_dim, mlp_dim, num_heads, dropout_rate=0.1, attn_dropout_rate=0.1):
        super(EncoderBlock, self).__init__()

        self.norm1 = nn.LayerNorm(in_dim)
        #self.attn = SelfAttention(in_dim, heads=num_heads, dropout_rate=attn_dropout_rate)
        ## note: not sure how exactly I pass the params 
        self.attn = ImageLinearAttention(in_dim, heads=num_heads, dropout_rate=attn_dropout_rate)
        ## rest of code

Loss returns Nan

Some of my settings
causal=true
blindspot_size=1
n_local_attn_heads
ff_chunks=1
reversible=false
use_axial_pos_emb=false

How to perform training?

Can someone please share a snippet of code on how to train the LinearAttentionTransformer on new tabular data?
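
Not an official recipe, but a minimal sketch of a standard next-token training loop for the LM; how you tokenize your tabular data into integer ids is up to you, and get_batch below is a hypothetical placeholder:

import torch
import torch.nn.functional as F
from linear_attention_transformer import LinearAttentionTransformerLM

model = LinearAttentionTransformerLM(
    num_tokens = 256,
    dim = 512,
    heads = 8,
    depth = 6,
    max_seq_len = 1024,
    causal = True
).cuda()

optim = torch.optim.Adam(model.parameters(), lr = 3e-4)

for step in range(10000):
    seq = get_batch().cuda()            # hypothetical: LongTensor of shape (batch, 1025)
    inp, target = seq[:, :-1], seq[:, 1:]

    logits = model(inp)                 # assumed shape (batch, 1024, num_tokens)
    loss = F.cross_entropy(logits.transpose(1, 2), target)

    loss.backward()
    optim.step()
    optim.zero_grad()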

Scaling factors

Hi!

I think the scaling factors in the linear attention are wrong.

You have

q = q.softmax(dim=-1)
q = q * dim ** -0.5

k = k.softmax(dim=-2)

In the paper, both Q and K are scaled, but each by the square root of the usual 1/sqrt(D_head) factor, i.e. each by D_head^-0.25. Shouldn't it be:

# First apply softmax, then scale - Equations 3 and 4 from the paper.
q = q.softmax(dim=-1)
k = k.softmax(dim=-2)

q = q * dim ** -0.25 # it's 1/sqrt(sqrt(d_head))
k = k * dim ** -0.25 # same.


Autopadder doesn't work with LinearAttentionTransformer

Is Autopadder supposed to work with LinearAttentionTransformer?

If I try this:

import torch
from linear_attention_transformer import LinearAttentionTransformer
from linear_attention_transformer.autopadder import Autopadder

model =  Autopadder(LinearAttentionTransformer(
    dim = 512,
    heads = 8,
    depth = 1,
    max_seq_len = 8192,
    n_local_attn_heads = 4
)).cuda()

x = torch.randn(1, 8191, 512).cuda()
model(x) # (1, 8191, 512)

I get this:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-8-5a45a93503d7> in <module>
     12 
     13 x = torch.randn(1, 8191, 512).cuda()
---> 14 model(x) # (1, 8192, 512)

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/autopadder.py in forward(self, x, **kwargs)
     53             kwargs.update(input_mask=new_mask)
     54 
---> 55         out = self.net(x, **kwargs)
     56 
     57         output_slice = slice(0, t) if not self.pad_left else slice(padding, None)

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/linear_attention_transformer.py in forward(self, x, **kwargs)
    354 
    355     def forward(self, x, **kwargs):
--> 356         return self.layers(x, **kwargs)
    357 
    358 class LinearAttentionTransformerLM(nn.Module):

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/reversible.py in forward(self, x, **kwargs)
    147 
    148         for (f, g), (f_args, g_args) in layers_and_args:
--> 149             x = x + f(x, **f_args)
    150             x = x + g(x, **g_args)
    151         return x

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/linear_attention_transformer.py in forward(self, x, **kwargs)
     65     def forward(self, x, **kwargs):
     66         x = self.norm(x)
---> 67         return self.fn(x, **kwargs)
     68 
     69 class Chunk(nn.Module):

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/linear_attention_transformer.py in forward(self, x, input_mask, context, context_mask, **kwargs)
    258 
    259         if has_local:
--> 260             local_out = self.local_attn(lq, lk, lv, input_mask = input_mask)
    261             out.append(local_out)
    262 

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/local_attention/local_attention.py in forward(self, q, k, v, input_mask)
    136         if input_mask is not None:
    137             h = b // input_mask.shape[0]
--> 138             input_mask = input_mask.reshape(-1, windows, window_size)
    139             mq = mk = input_mask
    140             mk = look_around(mk, pad_value=False, **look_around_kwargs)

RuntimeError: shape '[-1, 64, 128]' is invalid for input of size 4201983

I got rid of that error by doing this:

diff --git a/linear_attention_transformer/autopadder.py b/linear_attention_transformer/autopadder.py
index dd84663..d10927d 100644
--- a/linear_attention_transformer/autopadder.py
+++ b/linear_attention_transformer/autopadder.py
@@ -48,7 +48,10 @@ class Autopadder(nn.Module):
         x, padding = pad_to_multiple(x, self.pad_to, dim=self.pad_dim, pad_left=self.pad_left)

         if padding != 0:
-            offset = (0, padding) if not self.pad_left else (padding, 0)
+            if self.pad_dim == -1:
+                offset = (0, padding) if not self.pad_left else (padding, 0)
+            else:
+                offset = (0, 0, 0, padding) if not self.pad_left else (0, 0, padding, 0)
             new_mask = F.pad(input_mask, offset, value=False)
             kwargs.update(input_mask=new_mask)

But then I hit this:

import torch
from linear_attention_transformer import LinearAttentionTransformer
from linear_attention_transformer.autopadder import Autopadder

model =  Autopadder(LinearAttentionTransformer(
    dim = 128,
    heads = 4,
    depth = 1,
    max_seq_len = 256,
    n_local_attn_heads = 4
)).cuda()

x = torch.randn(1, 255, 512).cuda()
model(x)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-3a1475fdc86c> in <module>
     12 
     13 x = torch.randn(1, 255, 512).cuda()
---> 14 model(x) # (1, 8191, 512)

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/autopadder.py in forward(self, x, **kwargs)
     56             kwargs.update(input_mask=new_mask)
     57 
---> 58         out = self.net(x, **kwargs)
     59 
     60         output_slice = slice(0, t) if not self.pad_left else slice(padding, None)

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/linear_attention_transformer.py in forward(self, x, **kwargs)
    354 
    355     def forward(self, x, **kwargs):
--> 356         return self.layers(x, **kwargs)
    357 
    358 class LinearAttentionTransformerLM(nn.Module):

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/reversible.py in forward(self, x, **kwargs)
    147 
    148         for (f, g), (f_args, g_args) in layers_and_args:
--> 149             x = x + f(x, **f_args)
    150             x = x + g(x, **g_args)
    151         return x

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/linear_attention_transformer.py in forward(self, x, **kwargs)
     64         self.norm = nn.LayerNorm(dim)
     65     def forward(self, x, **kwargs):
---> 66         x = self.norm(x)
     67         return self.fn(x, **kwargs)
     68 

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/normalization.py in forward(self, input)
    168     def forward(self, input: Tensor) -> Tensor:
    169         return F.layer_norm(
--> 170             input, self.normalized_shape, self.weight, self.bias, self.eps)
    171 
    172     def extra_repr(self) -> Tensor:

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/functional.py in layer_norm(input, normalized_shape, weight, bias, eps)
   2047     """
   2048     return torch.layer_norm(input, normalized_shape, weight, bias, eps,
-> 2049                             torch.backends.cudnn.enabled)
   2050 
   2051 

RuntimeError: Given normalized_shape=[128], expected input with shape [*, 128], but got input of size[1, 256, 512]

Thanks for this great repo!

causal = True

Naive question!
causal = True: is this used to create a mask that trims/clips the upper-right half of the attention matrix, above the diagonal?

Thank you!
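
Roughly, yes: in standard attention a causal mask hides everything above the diagonal, so each position can only attend to itself and earlier positions. A toy illustration of that idea only (not this repo's internal mechanism, which uses linear and local attention):

import torch

# Toy illustration of a causal mask in vanilla attention: positions above the
# diagonal (the "future") are masked out before the softmax.
n = 5
scores = torch.randn(n, n)                         # raw attention logits
future = torch.ones(n, n).triu(1).bool()           # True strictly above the diagonal
scores = scores.masked_fill(future, float('-inf'))
attn = scores.softmax(dim = -1)                    # each query attends only to the past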

dalle

This is more of a question than an issue. Can this in theory be integrated with DALLE-pytorch to speed it up by a lot?

I want to try it, but first wanted to ask the author whether there are reasons this would not work. Thank you in advance.

seq2seq decoder ids

Hi there, thanks for all the work.

I am wondering how one should greedily sample from the model in a seq2seq setting. For example, let's say I want to do some MT. Is this how it works?

  1. Get context (output of encoder) using src
  2. Set tgt to be a start sequence token
  3. Get logits using context and tgt
  4. Find argmax of logits
  5. Add the argmax of logits token to the end of tgt
  6. Get new logits using context, and new tgt sequence, and repeat

So, tgt is essentially the decoder_ids (as in HuggingFace)
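
A minimal sketch of the greedy loop described above, reusing the enc / dec models from the README's encoder / decoder example. The start / end token ids and the maximum length are hypothetical, and depending on your local attention window you may need Autopadder (or manual padding) so the growing tgt length stays valid:

import torch

# Sketch only: greedy decoding with the enc / dec pair from the README example.
START_ID, EOS_ID, max_len = 1, 2, 512   # hypothetical token ids and decode budget

src = torch.randint(0, 20000, (1, 4096)).cuda()
src_mask = torch.ones_like(src).bool()

with torch.no_grad():
    context = enc(src, input_mask = src_mask)
    tgt = torch.full((1, 1), START_ID, dtype = torch.long).cuda()

    for _ in range(max_len):
        logits = dec(tgt, context = context, context_mask = src_mask)
        next_token = logits[:, -1].argmax(dim = -1, keepdim = True)   # greedy pick
        tgt = torch.cat((tgt, next_token), dim = -1)                  # append to tgt
        if next_token.item() == EOS_ID:
            break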
