jamba's Introduction

Jamba

PyTorch Implementation of Jamba: "Jamba: A Hybrid Transformer-Mamba Language Model"

install

$ pip install jamba

usage

# Import the torch library, which provides tools for machine learning
import torch

# Import the Jamba model from the jamba.model module
from jamba.model import Jamba

# Create a tensor of random token IDs from 0 to 99 (the upper bound 100 is exclusive), with shape (1, 100)
# This simulates a batch containing one sequence of 100 tokens to pass through the model
x = torch.randint(0, 100, (1, 100))

# Initialize the Jamba model with the specified parameters
# dim: model (embedding) dimension
# depth: number of layers in the model
# num_tokens: vocabulary size, i.e. the number of distinct token IDs
# d_state: size of the state-space (SSM) hidden state in each Mamba layer
# d_conv: kernel size of the 1D convolution in each Mamba layer
# heads: number of attention heads
# num_experts: number of expert feed-forward networks in each mixture-of-experts layer
# num_experts_per_token: number of experts each token is routed to (top-k)
model = Jamba(
    dim=512,
    depth=6,
    num_tokens=100,
    d_state=256,
    d_conv=128,
    heads=8,
    num_experts=8,
    num_experts_per_token=2,
)

# Perform a forward pass through the model with the input data
# This will return the model's predictions for each token in the input data
output = model(x)

# Print the model's predictions
print(output)
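The constructor arguments above map onto the paper's two main ideas: interleaving Mamba layers with occasional attention layers, and replacing some feed-forward layers with mixture-of-experts (MoE) layers. To make that layout concrete, here is a minimal sketch that lists the layer types in one hybrid block using the configuration reported in the Jamba paper (8 layers per block, one attention layer for every seven Mamba layers, and an MoE layer in every second layer). It is not taken from this repository, whose JambaBlock may arrange its layers differently; the helper jamba_block_layout, the LayerSpec type, and the attn_index position of the attention layer are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LayerSpec:
    mixer: str  # sequence mixer: "attention" or "mamba"
    ffn: str    # feed-forward type: "moe" or "mlp"


def jamba_block_layout(num_layers=8, attn_every=8, attn_index=4, moe_every=2):
    """Hypothetical helper: list the layer types of one hybrid block.

    Defaults follow the ratios reported in the Jamba paper (one attention
    layer per 8 layers, MoE in every 2nd layer). attn_index, the position
    of the attention layer inside each group, is an assumption.
    """
    layers = []
    for i in range(num_layers):
        mixer = "attention" if i % attn_every == attn_index else "mamba"
        ffn = "moe" if i % moe_every == moe_every - 1 else "mlp"
        layers.append(LayerSpec(mixer, ffn))
    return layers


if __name__ == "__main__":
    for i, spec in enumerate(jamba_block_layout()):
        print(i, spec.mixer, spec.ffn)
```

With the default arguments the block has 1 attention layer and 7 Mamba layers, and 4 of the 8 layers use an MoE feed-forward.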

Train

python3 train.py

License

MIT

jamba's People

Contributors

kyegomez

jamba's Issues

zeta

May I ask whether zeta, which is imported in the model file in the jamba folder, is a package or an unknown file?
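One way to answer this in your own environment is to ask Python's import system where it would load zeta from, without importing it. The minimal sketch below uses only the standard library; the helper name describe_module is hypothetical. If the reported path points into site-packages (or dist-packages), zeta is a separately installed third-party package rather than a file inside the jamba folder.

```python
import importlib.util


def describe_module(name):
    """Report whether a top-level name resolves to a package, a single-file
    module, or nothing, without executing the module's code."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return f"{name}: not found on sys.path"
    if spec.submodule_search_locations is not None:
        # Packages have a search path for their submodules.
        return f"{name}: package at {spec.origin}"
    return f"{name}: module at {spec.origin}"


if __name__ == "__main__":
    # "zeta" is the name that jamba/model.py imports from.
    print(describe_module("zeta"))
```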

Urgent Query

I am confused about the MoE layer in the Jamba block. There are many versions of MoE, and the paper does not define the mathematics in detail or provide diagrams that explain the expert system. Could you please point me to, or share, the exact paper that Jamba follows?
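For reference, a common MoE formulation in recent language models is top-k gating: a learned router scores every expert for each token, only the k highest-scoring experts are evaluated, and their outputs are mixed using softmax weights computed over the selected experts. The plain-Python sketch below illustrates that generic scheme for a single token. Whether this repository or the Jamba paper normalizes the gate weights in exactly this way is an assumption, and the toy experts and router weights are made up for illustration.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def softmax(xs: Sequence[float]) -> Vector:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_moe(
    x: Vector,
    router: Sequence[Vector],  # one routing weight row per expert
    experts: Sequence[Callable[[Vector], Vector]],
    k: int = 2,
) -> Vector:
    """Route a single token vector x through its top-k experts."""
    # Router logits: dot product of x with each expert's routing row.
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in router]
    # Indices of the k highest-scoring experts.
    chosen = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax over the chosen logits only (equal to the full softmax
    # renormalized over the top-k experts).
    gates = softmax([logits[i] for i in chosen])
    out = [0.0] * len(x)
    for gate, i in zip(gates, chosen):
        y = experts[i](x)  # only the selected experts are evaluated
        out = [o + gate * yi for o, yi in zip(out, y)]
    return out


if __name__ == "__main__":
    # Toy setup: 4 experts that each scale the input by a different factor.
    experts = [lambda v, s=s: [s * vi for vi in v] for s in (1.0, 2.0, 3.0, 4.0)]
    router = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, -1.0]]
    print(top_k_moe([1.0, 2.0], router, experts, k=2))
```

In the README example, num_experts=8 and num_experts_per_token=2 would correspond to len(experts) == 8 and k == 2 in this sketch.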

[BUG] The example on the README is not working

Describe the bug

I tried to run the example code in a clean environment, just out of curiosity, and could not reproduce the output.

To Reproduce

The problem is also reproduced in this Colab notebook: https://colab.research.google.com/drive/1OnyI7WfXUkqXqscz8QiFUErDHA2kEfm5?usp=sharing

Steps to reproduce the behavior:

  1. Create a fresh conda, pipenv, or Colab environment.
  2. Go to the README file.
  3. Install jamba.
  4. Run the usage example from the README.
  5. See the error traceback:
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-1-5e7c358db527> in <cell line: 5>()
      3 
      4 # Import the Jamba model from the jamba.model module
----> 5 from jamba.model import Jamba
      6 
      7 # Create a tensor of random integers between 0 and 100, with shape (1, 100)

/usr/local/lib/python3.10/dist-packages/jamba/__init__.py in <module>
----> 1 from jamba.model import JambaBlock, Jamba
      2 
      3 __all__ = ["JambaBlock", "Jamba"]

/usr/local/lib/python3.10/dist-packages/jamba/model.py in <module>
      1 from torch import Tensor, nn
----> 2 from zeta import MambaBlock
      3 from zeta.nn import FeedForward
      4 from zeta import MultiQueryAttention
      5 from zeta.nn.modules.simple_rmsnorm import SimpleRMSNorm

/usr/local/lib/python3.10/dist-packages/zeta/__init__.py in <module>
     26 logger.addFilter(f)
     27 
---> 28 from zeta.nn import *
     29 from zeta.models import *
     30 from zeta.utils import *

/usr/local/lib/python3.10/dist-packages/zeta/nn/__init__.py in <module>
----> 1 from zeta.nn.attention import *
      2 from zeta.nn.embeddings import *
      3 from zeta.nn.modules import *
      4 from zeta.nn.biases import *

/usr/local/lib/python3.10/dist-packages/zeta/nn/attention/__init__.py in <module>
     12 # from zeta.nn.attention.mgqa import MGQA
     13 # from zeta.nn.attention.spatial_linear_attention import SpatialLinearAttention
---> 14 from zeta.nn.attention.mixture_attention import (
     15     MixtureOfAttention,
     16     MixtureOfAutoregressiveAttention,

/usr/local/lib/python3.10/dist-packages/zeta/nn/attention/mixture_attention.py in <module>
      6 from typing import Tuple, Optional
      7 from einops import rearrange, repeat, reduce
----> 8 from zeta.models.vit import exists
      9 from zeta.structs.transformer import RMSNorm, apply_rotary_pos_emb
     10 

/usr/local/lib/python3.10/dist-packages/zeta/models/__init__.py in <module>
      1 # Copyright (c) 2022 Agora
      2 # Licensed under The MIT License [see LICENSE for details]
----> 3 from zeta.models.andromeda import Andromeda
      4 from zeta.models.base import BaseModel
      5 from zeta.models.gpt4 import GPT4, GPT4MultiModal

/usr/local/lib/python3.10/dist-packages/zeta/models/andromeda.py in <module>
      2 from torch.nn import Module
      3 
----> 4 from zeta.structs.auto_regressive_wrapper import AutoregressiveWrapper
      5 from zeta.structs.transformer import (
      6     Decoder,

/usr/local/lib/python3.10/dist-packages/zeta/structs/__init__.py in <module>
      2 from zeta.structs.encoder_decoder import EncoderDecoder
      3 from zeta.structs.hierarchical_transformer import HierarchicalTransformer
----> 4 from zeta.structs.local_transformer import LocalTransformer
      5 from zeta.structs.parallel_transformer import ParallelTransformerBlock
      6 from zeta.structs.transformer import (

/usr/local/lib/python3.10/dist-packages/zeta/structs/local_transformer.py in <module>
      6 from zeta.nn.attention.local_attention_mha import LocalMHA
      7 from zeta.nn.biases.dynamic_position_bias import DynamicPositionBias
----> 8 from zeta.nn.modules import feedforward_network
      9 from zeta.utils.main import eval_decorator, exists, top_k
     10 

/usr/local/lib/python3.10/dist-packages/zeta/nn/modules/__init__.py in <module>
     45 from zeta.nn.modules.s4 import s4d_kernel
     46 from zeta.nn.modules.h3 import H3Layer
---> 47 from zeta.nn.modules.mlp_mixer import MLPMixer
     48 from zeta.nn.modules.leaky_relu import LeakyRELU
     49 from zeta.nn.modules.adaptive_layernorm import AdaptiveLayerNorm

/usr/local/lib/python3.10/dist-packages/zeta/nn/modules/mlp_mixer.py in <module>
    143     1, 512, 32, 32
    144 )  # Batch size of 1, 512 channels, 32x32 image
--> 145 output = mlp_mixer(example_input)
    146 print(
    147     output.shape

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

/usr/local/lib/python3.10/dist-packages/zeta/nn/modules/mlp_mixer.py in forward(self, x)
    123         x = rearrange(x, "n c h w -> n (h w) c")
    124         for mixer_block in self.mixer_blocks:
--> 125             x = mixer_block(x)
    126         x = self.pred_head_layernorm(x)
    127         x = x.mean(dim=1)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

/usr/local/lib/python3.10/dist-packages/zeta/nn/modules/mlp_mixer.py in forward(self, x)
     61         y = self.norm1(x)
     62         y = rearrange(y, "n c t -> n t c")
---> 63         y = self.tokens_mlp(y)
     64         y = rearrange(y, "n t c -> n c t")
     65         x = x + y

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

/usr/local/lib/python3.10/dist-packages/zeta/nn/modules/mlp_mixer.py in forward(self, x)
     28             torch.Tensor: _description_
     29         """
---> 30         y = self.dense1(x)
     31         y = F.gelu(y)
     32         return self.dense2(y)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py in forward(self, input)
    114 
    115     def forward(self, input: Tensor) -> Tensor:
--> 116         return F.linear(input, self.weight, self.bias)
    117 
    118     def extra_repr(self) -> str:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x4 and 512x512)
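Reading the traceback from the bottom up, the shape error is not raised by jamba's own code: importing zeta.nn.modules.mlp_mixer executes example code at module level (lines 143 to 147 build an example input and call mlp_mixer on it), and that example fails. Module-level code runs on every import, so a broken demo there breaks every downstream import. The usual Python idiom is to keep such demos under an if __name__ == "__main__": guard. The self-contained sketch below (standard library only; the module names demo_unguarded and demo_guarded and the helper import_from_source are hypothetical) writes two tiny modules to a temporary directory and shows that only the unguarded one runs its demo when imported.

```python
import importlib
import sys
import tempfile
from pathlib import Path

UNGUARDED_SRC = """
DEMO_RAN = False

def double(x):
    return 2 * x

# Example usage left at module level: runs on every import.
DEMO_RAN = double(21) == 42
"""

GUARDED_SRC = """
DEMO_RAN = False

def double(x):
    return 2 * x

if __name__ == "__main__":
    # Example usage only runs when the file is executed directly.
    DEMO_RAN = double(21) == 42
"""


def import_from_source(name, source, directory):
    """Write source to directory/name.py and import it as a module."""
    Path(directory, f"{name}.py").write_text(source)
    if directory not in sys.path:
        sys.path.insert(0, directory)
    importlib.invalidate_caches()  # make the new file visible to the importer
    return importlib.import_module(name)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        unguarded = import_from_source("demo_unguarded", UNGUARDED_SRC, tmp)
        guarded = import_from_source("demo_guarded", GUARDED_SRC, tmp)
        print("unguarded demo ran on import:", unguarded.DEMO_RAN)
        print("guarded demo ran on import:", guarded.DEMO_RAN)
```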
