LeaPformer

This repository contains the official implementation of "LeaPformer: Enabling Linear Transformers for Autoregressive and Simultaneous Tasks via Learned Proportions," the preprint for which can be found here. LeaPformers are, fundamentally, a novel modification of specific re-weighting functions for linear attention mechanisms that enables them for a wider range of tasks. Thanks to this added flexibility, LeaPformers are often more accurate than alternative linear attention mechanisms while adding only a small amount of latency.
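For intuition, the following is a minimal, non-causal sketch of the idea, assuming a cosFormer-style cosine re-weighting in which the static position ratio is replaced by a per-token proportion predicted by a small learned module. All names, shapes, and the feature map are illustrative assumptions, not the repo's actual API; see pytorch-lra and fairseq-leapformer for the real implementations.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LearnedProportion(nn.Module):
    """Illustrative (hypothetical) module: predicts a proportion in [0, 1] per token
    from its representation, standing in for the static position ratio i / N used by
    cosFormer-style re-weighting."""
    def __init__(self, dim: int, hidden: int = 16):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, x):                  # x: (batch, seq, dim)
        return torch.sigmoid(self.net(x))  # (batch, seq, 1), values in [0, 1]

def leap_reweighted_linear_attention(q, k, v, p_q, p_k, eps=1e-6):
    """Non-causal linear attention with cos(pi/2 * (p_i - p_j)) re-weighting driven by
    learned proportions p_q, p_k instead of explicit position ratios."""
    q, k = F.relu(q), F.relu(k)                       # non-negative feature map
    # cos(a_i - a_j) = cos(a_i)cos(a_j) + sin(a_i)sin(a_j), with a = pi/2 * p,
    # so the re-weighting folds into q and k and attention stays linear in seq length.
    a_q, a_k = 0.5 * torch.pi * p_q, 0.5 * torch.pi * p_k
    q = torch.cat([q * torch.cos(a_q), q * torch.sin(a_q)], dim=-1)
    k = torch.cat([k * torch.cos(a_k), k * torch.sin(a_k)], dim=-1)
    kv = torch.einsum("bnd,bne->bde", k, v)           # (d', e) summary, no (n, n) matrix
    z = 1.0 / (torch.einsum("bnd,bd->bn", q, k.sum(dim=1)) + eps)
    return torch.einsum("bnd,bde,bn->bne", q, kv, z)
```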

Set-up for the various parts of this repo is somewhat separated, as they were occasionally validated in different environments (e.g. the environment for LRA tests was not necessarily identical to the environment for LM or SimulST due to some compatibility issues). Set-up instructions are provided in pytorch-lra and fairseq-leapformer.

LeaPformers on the Long-Range Arena (LRA) Benchmark

Our slightly modified version of the Skyformer PyTorch LRA benchmark can be found in pytorch-lra, containing several additional linear attention mechanisms compared to the original implementation. Details for running the LRA benchmark are also provided there, including some example scripts.

As a note, this particular set-up focuses on extremely small models, which allows quadratic, softmax attention to be tested on long-sequence tasks even on medium-to-low-end hardware.
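To make that point concrete: a softmax baseline materializes an N x N attention matrix, while linear attention only carries a d x d summary, so on LRA-length sequences the quadratic baseline stays tractable only when the model (and especially the head dimension) is kept small. The snippet below is a generic, illustrative comparison and not code from pytorch-lra.

```python
import torch

b, n, d = 1, 4096, 64                      # LRA-scale sequence length, small head dim
q, k, v = (torch.randn(b, n, d) for _ in range(3))

# Quadratic softmax attention: an (n, n) score matrix dominates memory, O(n^2 * d) work.
scores = torch.einsum("bnd,bmd->bnm", q, k) / d ** 0.5        # (b, n, n)
out_softmax = torch.einsum("bnm,bmd->bnd", scores.softmax(dim=-1), v)

# Linear attention (unnormalized, non-causal): only a (d, d) summary, O(n * d^2) work.
phi = lambda x: torch.nn.functional.elu(x) + 1                # non-negative feature map
kv = torch.einsum("bnd,bne->bde", phi(k), v)                  # (b, d, d)
out_linear = torch.einsum("bnd,bde->bne", phi(q), kv)         # (b, n, d)
```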

LeaPformers on Autoregressive Language Modeling

We validated LeaPformers on small-scale autoregressive language modeling (i.e. around 140M parameters) via an older, private fork of Fairseq, to be provided in fairseq-leapformer (still being cleaned up, as the initial implementation was ad hoc). Scripts are available in fairseq-leapformer/leapformer-scripts/lm and, should one want a more up-to-date version of Fairseq, it can be found here.

Cleaning up. Will be finished soon.

LeaPformers on S2T Simultaneous Translation (SimulST)

Similarly, we validated LeaPformers on SimulST with that same Fairseq fork. Unlike the autoregressive language modeling example, the SimulST changes also extend into fairseq-leapformer/examples/speech_to_text/simultaneous_translation/agents and fairseq-leapformer/examples/simultaneous_translation, where some custom encoder-decoder masking occurs and the SimulEval agent is modified. Scripts are available in fairseq-leapformer/leapformer-scripts/simulst.

Cleaning up. Will be finished soon.

What about more performant causal training/inference?

As mentioned in this work, our implementations (especially the causal ones) are not optimized. A number of works have demonstrated the importance of hardware-aware implementations for maximizing performance. An obvious next step here would be constructing a Triton-based LeaPformer implementation (à la Flash Linear Attention, or FLA). In fact, integration with FLA is likely simple, especially for decoder-only applications (e.g. autoregressive language modeling), requiring only that the LeaPformer transforms be applied to the query and key before calling FLA's specialized kernels.
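As a rough sketch of what that integration could look like, the snippet below applies a hypothetical LeaPformer-style transform to the query and key and then runs a naive recurrent causal linear attention in plain PyTorch. In an actual FLA integration, that final step would be replaced by one of FLA's fused Triton kernels; they are deliberately not called here to avoid guessing at their exact signatures.

```python
import torch
import torch.nn.functional as F

def leap_transform(q, k, p_q, p_k):
    """Hypothetical pre-kernel transform: fold learned proportions (p_q, p_k in [0, 1])
    into q and k so a downstream fused kernel only sees ordinary linear attention."""
    a_q, a_k = 0.5 * torch.pi * p_q, 0.5 * torch.pi * p_k
    q, k = F.relu(q), F.relu(k)
    q = torch.cat([q * torch.cos(a_q), q * torch.sin(a_q)], dim=-1)
    k = torch.cat([k * torch.cos(a_k), k * torch.sin(a_k)], dim=-1)
    return q, k

def naive_causal_linear_attention(q, k, v, eps=1e-6):
    """Reference recurrence, linear in sequence length; a fused kernel would replace this loop."""
    b, n, d = q.shape
    e = v.shape[-1]
    kv = q.new_zeros(b, d, e)        # running sum of k_j v_j^T over the prefix
    k_sum = q.new_zeros(b, d)        # running sum of k_j for the normalizer
    out = []
    for t in range(n):
        kv = kv + torch.einsum("bd,be->bde", k[:, t], v[:, t])
        k_sum = k_sum + k[:, t]
        num = torch.einsum("bd,bde->be", q[:, t], kv)
        den = torch.einsum("bd,bd->b", q[:, t], k_sum).unsqueeze(-1) + eps
        out.append(num / den)
    return torch.stack(out, dim=1)   # (b, n, e)

# Usage sketch: q2, k2 = leap_transform(q, k, p_q, p_k)
#               out = naive_causal_linear_attention(q2, k2, v)
```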

Other future steps for LeaPformers?

LeaPformers were originally conceived back in mid-2023, and a number of interesting works have been published since then containing elements that could be applied to LeaPformers. For example:

  1. There are no RNN-like gating mechanisms in this work, despite concurrent work like Gated Linear Attention (GLA) using them to great effect.
  2. Moreover, several works have skipped the time-dependent normalization term in linear attention in favor of normalization blocks on the output (e.g. LayerNorm or GroupNorm, seen in papers here and here, and similarly in GLA). In our experiments this made no real difference, but it might at scale (a minimal sketch of both options follows this list).
  3. Finally, the scale of the experiments in this work is ultimately small for modern applications, and it would be very attractive to experiment at larger scales (i.e. from roughly 300M to several billion parameters).
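The following is a minimal, non-causal sketch of point 2, assuming a simple ReLU feature map: option A uses the classic time-dependent denominator, while option B drops it and normalizes the output with a norm block instead (LayerNorm here purely for illustration; the cited works and GLA typically use GroupNorm-style per-head normalization). Nothing here is taken from this repo's implementation.

```python
import torch
import torch.nn as nn

b, n, d = 2, 128, 64
q, k, v = (torch.relu(torch.randn(b, n, d)) for _ in range(3))   # non-negative features

# Option A: classic time-dependent normalizer (denominator depends on the keys seen).
kv = torch.einsum("bnd,bne->bde", k, v)
den = torch.einsum("bnd,bd->bn", q, k.sum(dim=1)).clamp(min=1e-6)
out_a = torch.einsum("bnd,bde->bne", q, kv) / den.unsqueeze(-1)

# Option B: drop the denominator and normalize the output with a norm block instead.
norm = nn.LayerNorm(d)
out_b = norm(torch.einsum("bnd,bde->bne", q, kv))
```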

Reference

If you found our work insightful or useful, please consider citing us as:

@inproceedings{agostinelli2024leapformer,
      title={LeaPformer: Enabling Linear Transformers for Autoregressive and Simultaneous Tasks via Learned Proportions},
      author={Victor Agostinelli and Sanghyun Hong and Lizhong Chen},
      booktitle={Forty-first International Conference on Machine Learning},
      year={2024},
      url={https://openreview.net/forum?id=XhH1OKLANY}
}
