Mamba Model in Easy PyTorch 🐍

This repo contains an unofficial PyTorch implementation of the Mamba model as introduced in Gu & Dao (2023). One can find the official (CUDA) implementation here or the nice alxndrTL alternative Python implementation. This repo is developed mainly for didactic purposes, to spell out the details of State Space Models.

The main contribution of this repo is the implementation of the Mamba architecture (upon which a basic LLM is built, see 📂 mamba.llm) using the Lightning framework, which abstracts away the difficulties of training in distributed settings.
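
For reference, the heart of the model is the discretized state space recurrence h_t = exp(ΔA) h_{t-1} + (ΔB) x_t with readout y_t = C h_t, where the input-dependent step Δ controls the discretization (the B term uses the simplified Euler form from the Mamba paper). The snippet below is a minimal sequential sketch of that scan for didactic purposes only, not this repo's actual implementation; all names and shapes are illustrative assumptions.

import torch

def ssm_scan(x, A, B, C, delta):
    # Illustrative sequential scan of a selective SSM.
    # x: (batch, seq_len, d_model), A: (d_model, d_state),
    # B, C: (batch, seq_len, d_state), delta: (batch, seq_len, d_model)
    batch, seq_len, d_model = x.shape
    d_state = A.shape[-1]

    # Zero-order-hold discretization, with the simplified B_bar = delta * B
    A_bar = torch.exp(delta.unsqueeze(-1) * A)    # (batch, seq_len, d_model, d_state)
    B_bar = delta.unsqueeze(-1) * B.unsqueeze(2)  # (batch, seq_len, d_model, d_state)

    h = torch.zeros(batch, d_model, d_state, device=x.device)
    ys = []
    for t in range(seq_len):
        # Linear recurrence: h_t = A_bar_t * h_{t-1} + B_bar_t * x_t
        h = A_bar[:, t] * h + B_bar[:, t] * x[:, t].unsqueeze(-1)
        # Readout: y_t = C_t h_t, contracting the state dimension
        ys.append((h * C[:, t].unsqueeze(1)).sum(dim=-1))

    return torch.stack(ys, dim=1)  # (batch, seq_len, d_model)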

Usage

The basic usage is to instantiate a Mamba model and run it on an input of choice. The model expects the input tensor to have shape (batch_size, seq_len, d_input) and outputs a similarly shaped tensor.

import torch

from mamba import Mamba

d_input = 16
seq_len = 32
batch_size = 4

model = Mamba(
  num_layers = 6,    # Number of layers of the full model
  d_input = d_input, # Dimension of each vector in the input sequence (i.e. token size)
  d_model = 64,      # Dimension of the visible state space
  d_state = 64,      # Dimension of the latent hidden states
  d_discr = 16,      # Rank of the discretization matrix Δ
  ker_size = 4,      # Kernel size of the convolution in the MambaBlock
  parallel = False,  # Whether to use the parallel (True) or the sequential (False) implementation
)

# Mockup input for example purposes
inp = torch.randn(batch_size, seq_len, d_input)

# Compute the output using the Mamba architecture
out = model(inp) # (batch_size, seq_len, d_input)

This repo also offers a didactic implementation of an LLM built using Lightning, which unlocks easy training on multiple GPUs. To use it, one can simply run the following example:

from lightning import Trainer
from transformers import AutoTokenizer

from mamba import MambaLLM
from mamba.tiny import TinyStoriesLightning

config = ... # path to YAML configuration file

# Load an off-the-shelf tokenizer from HF
tokenizer = AutoTokenizer.from_pretrained('openai-community/gpt2')

# Load the Mamba model from a config file
model = MambaLLM.from_config(config, key='mamba')

# Load the dataset
dataset = TinyStoriesLightning.from_config(
    config,
    tokenizer,
    key='dataset'
)

trainer = Trainer(
  max_epochs  = 500,
  accelerator = 'gpu',
  devices     = 4, # Piece of cake multi-gpu support!
  strategy    = 'ddp_find_unused_parameters_false',
)

# Train the model
trainer.fit(model, dataset)

Alternatively, one can run the training script train.py directly; it expects the configuration file path and accepts all the Trainer arguments.

python train.py --config <path_to_config_file> \
  --max_epochs 500 \
  --accelerator gpu \
  --devices 4
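
Both snippets above read the model and the dataset hyper-parameters from a YAML file via the 'mamba' and 'dataset' keys. The exact schema is defined by the example configs shipped with the repo; the following sketch is only illustrative, with field names mirrored from the constructor arguments shown in this README.

# Hypothetical configuration sketch, the real schema may differ
mamba:
  num_layers: 6
  d_input: 16
  d_model: 64
  d_state: 64
  d_discr: 16
  ker_size: 4
  parallel: false

dataset:
  # Options for TinyStoriesLightning (schema not shown in this README)
  ...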

A cool feature of the MambaLLM implementation is lazy (batched) inference, implemented via a generator: one can print tokens on screen as they are streamed by the model, with no need to wait for the whole inference to finish! A mock-up script would look like the following.

from mamba import MambaLLM
from transformers import AutoTokenizer

# Get an off-the-shelf tokenizer
tokenizer = AutoTokenizer.from_pretrained('openai-community/gpt2')

tokenizer.add_special_tokens({'pad_token': '<|pad|>'})

# Parameters for the LLM
vocab_size = tokenizer.vocab_size + 1  # +1 for the <|pad|> token added above
num_layers = 6
d_input = 16
d_model = 64
d_state = 64
d_discr = 16
ker_size = 4
parallel = True

model = MambaLLM(
    vocab_size = vocab_size,
    num_layers = num_layers,
    d_input = d_input,
    d_model = d_model,
    d_state = d_state,
    d_discr = d_discr,
    ker_size = ker_size,
    parallel = parallel
)

# Parameters for the inference
token_lim = 16
use_top_k = 50
temperature = 0.7

# Generate text
stream = model.generate(
  # We can provide more than one prompt!
  prompt=[
      'Once upon a time',
      'In a galaxy far far away',
  ],
  tokenizer=tokenizer,
  token_lim=token_lim,
  use_top_k=use_top_k,
  temperature=temperature,
)

for token in stream:
    # Each token is a dictionary indexed by the
    # batch-id and contains the produced string
    # as value, so we can print the first batch as:
    print(token[0], end='')
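
Because generation is lazy, one can just as easily accumulate the streamed pieces into the full per-prompt strings. Below is a minimal sketch assuming, as in the loop above, that each yielded item is a dictionary mapping a batch index to the newly produced string; note that a generator is single-use, so this loop replaces the printing loop rather than following it.

from collections import defaultdict

# Accumulate the streamed pieces for every prompt in the batch
texts = defaultdict(str)

for token in stream:
    for batch_id, piece in token.items():
        texts[batch_id] += piece

for batch_id, text in sorted(texts.items()):
    print(f'Prompt {batch_id}: {text}')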

Roadmap

  • Put all the essential pieces together
  • Add functioning parallel implementation (p-scan, see the sketch after this list; 🙏🏻 @ Zeta Project & alxndrTL)
  • Add functioning training script (Lightning)
  • Show some results
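
For the curious: a parallel implementation is possible because the recurrence h_t = a_t * h_{t-1} + b_t is associative under the composition (a1, b1) ∘ (a2, b2) = (a1 * a2, a2 * b1 + b2), so all prefixes can be computed in O(log L) sweeps instead of L sequential steps. The sketch below illustrates the idea with a simple Hillis-Steele scan; it is not the repo's actual p-scan, and all names are assumptions.

import torch

def linear_pscan(a, b):
    # Computes h_t = a_t * h_{t-1} + b_t (with h_0 = 0) for every t,
    # using O(log L) sweeps over the sequence dimension.
    # a, b: (batch, seq_len, dim)
    L, step = a.shape[1], 1
    while step < L:
        # Combine each position with the one `step` slots back:
        # (a1, b1) followed by (a2, b2) -> (a1 * a2, a2 * b1 + b2)
        new_a = a[:, step:] * a[:, :-step]
        new_b = a[:, step:] * b[:, :-step] + b[:, step:]
        a = torch.cat([a[:, :step], new_a], dim=1)
        b = torch.cat([b[:, :step], new_b], dim=1)
        step *= 2
    return b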

Requirements

This repo builds on top of some cool libraries; one can easily install the dependencies via pip install -r requirements.txt

einops==0.7.0
lightning==2.2.2
PyYAML==6.0.1
torch==2.0.1
transformers==4.40.0

Citations

Some of the code is inspired by the great repo LLM From Scratch, which served as a guide to clarify several doubts along the way.

@article{gu2023mamba,
  title={Mamba: Linear-time sequence modeling with selective state spaces},
  author={Gu, Albert and Dao, Tri},
  journal={arXiv preprint arXiv:2312.00752},
  year={2023}
}
