Comments (6)

Zymrael avatar Zymrael commented on August 23, 2024

Memory scaling is approximately linear, as in FlashAttention, so OOM at sequence length 1500 sounds strange. What is your exact setup (Hyena hyperparameters, attention hyperparameters, sequence length, batch size)?

mgaido91 avatar mgaido91 commented on August 23, 2024

Yes, sorry. Here is a brief description of the setup. I am replacing the self-attention in a Conformer model with the Hyena operator. The Hyena config is the following:

HyenaOperator(
          (dropout): Dropout(p=0.1, inplace=False)
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
          (in_proj): Linear(in_features=512, out_features=1536, bias=True)
          (short_filter): Conv1d(1536, 1536, kernel_size=(3,), stride=(1,), padding=(1,), groups=1536)
          (filter_fn): HyenaFilter(
            (dropout): Dropout(p=0.1, inplace=False)
            (pos_emb): ComplexExponentialPositionalEmbedding()
            (implicit_filter): Sequential(
              (0): Linear(in_features=3, out_features=64, bias=True)
              (1): Sin()
              (2): Linear(in_features=64, out_features=64, bias=True)
              (3): Sin()
              (4): Linear(in_features=64, out_features=64, bias=True)
              (5): Sin()
              (6): Linear(in_features=64, out_features=64, bias=False)
            )
            (modulation): ExponentialModulation()
          )
        )

with the following additional hyperparameters:

            order=2,
            filter_order=64,
            num_heads=1,
            inner_factor=1,
            num_blocks=1,
            outer_mixing=False,
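
For reference, a minimal standalone sketch of the instantiation (constructor arguments and import path are taken approximately from the safari HyenaOperator; the Conformer wiring around it is omitted):

    import torch
    from src.models.sequence.hyena import HyenaOperator  # import path assumed from the safari repo layout

    # Approximate instantiation matching the printed module: d_model=512, order=2,
    # filter_order=64; the remaining arguments listed above are left at their defaults.
    layer = HyenaOperator(d_model=512, l_max=6000, order=2, filter_order=64, dropout=0.1)

    x = torch.randn(4, 1500, 512)  # (batch, seq_len, d_model)
    y = layer(x)                   # output has the same shape as x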

With this setting, I am able to train a Conformer with self-attention using sequence length 1500 and a batch size of 40k tokens on a 40GB GPU.

Also, I made a mistake in the previous description. Replacing the self-attention with the Hyena operator, I observe the following:

  • With maximum sequence length 6k and batch size 40k tokens, it goes OOM on a 40GB GPU;
  • With maximum sequence length 1500 and batch size 40k tokens, it works on a 40GB GPU, but an attention-based model also works in this condition;
  • With maximum sequence length 6k and batch size 20k tokens, it works without OOM, whereas an attention-based model goes OOM in this case.

So it does seem to reduce memory usage with respect to an attention-based model, but the gain in the range of feasible lengths seems limited (given a log-linear vs. quadratic complexity). I am wondering if you have any data you can share on the expected memory usage, to understand when the benefits of Hyena become more visible. Thanks.
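
A sketch of how the per-layer peak memory could be probed at a fixed token budget, in case that is a useful way to compare (illustrative only; the full Conformer stack will of course use more, and hyena_layer / attention_layer are placeholders):

    import torch

    def peak_memory_mb(module, seq_len, tokens_per_batch=40_000, d_model=512, device="cuda"):
        """Peak GPU memory (MB) for one forward + backward pass at a fixed token budget."""
        batch = max(1, tokens_per_batch // seq_len)
        x = torch.randn(batch, seq_len, d_model, device=device, requires_grad=True)
        torch.cuda.reset_peak_memory_stats(device)
        module(x).sum().backward()
        return torch.cuda.max_memory_allocated(device) / 2**20

    # e.g. compare a Hyena layer against a plain attention block at growing lengths:
    # for L in (1500, 3000, 6000):
    #     print(L, peak_memory_mb(hyena_layer, L), peak_memory_mb(attention_layer, L))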

Zymrael avatar Zymrael commented on August 23, 2024

Is attention using FlashAttention?

mgaido91 avatar mgaido91 commented on August 23, 2024

No, it is a plain attention implementation.
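
(For context: a plain implementation materializes the full L×L attention matrix, so its activation memory grows quadratically with the sequence length. A fused kernel avoids that. A minimal sketch using PyTorch's built-in scaled_dot_product_attention, assuming PyTorch ≥ 2.0; this is not the exact Conformer attention module:)

    import torch
    import torch.nn.functional as F

    # Illustrative shapes: (batch, heads, seq_len, head_dim), fp16 on GPU.
    q = torch.randn(8, 8, 1500, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # Dispatches to a fused FlashAttention-style kernel when available, so the
    # explicit seq_len x seq_len attention matrix is never allocated.
    out = F.scaled_dot_product_attention(q, k, v)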

DanFu09 avatar DanFu09 commented on August 23, 2024

What is your model size? You may just be seeing the memory footprint of the number of tokens + model weights + activations.

mgaido91 avatar mgaido91 commented on August 23, 2024

The model size is 1.3GB. Indeed, if I set the batch size to 10k tokens with sequence length 6k, I see that only 22GB of GPU RAM are used.
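
A rough back-of-envelope split of that 22GB (assuming fp32 weights and gradients and an Adam-style optimizer with two moment buffers; all figures approximate):

    # Rough memory-budget sketch; every figure here is approximate.
    weights_gb = 1.3                # reported model size
    grads_gb   = weights_gb         # one gradient buffer per parameter
    optim_gb   = 2 * weights_gb     # Adam keeps two extra buffers (m and v)
    static_gb  = weights_gb + grads_gb + optim_gb   # ~5.2 GB before activations

    tokens = 10_000
    activation_gb = 22 - static_gb  # remainder of the observed 22 GB
    print(f"~{activation_gb:.1f} GB activations, "
          f"~{activation_gb * 1024 / tokens:.2f} MB per token across the whole stack")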
