mamba.py's Issues

Parallel pscan

I am looking at the file pscan.py, and see:

class PScan(torch.autograd.Function):
    @staticmethod
    def pscan(A, X):
        # A : (B, D, L, N)
        # X : (B, D, L, N)

        # modifies X in place by doing a parallel scan.
        # more formally, X will be populated by these values :
        # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
        # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)

        # only supports L that is a power of two (mainly for a clearer code)

        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # up sweep (last 2 steps unfolded)
        Aa = A
        Xa = X
        for _ in range(num_steps-2):
            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)

            Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
            Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])

            Aa = Aa[:, :, :, 1]
            Xa = Xa[:, :, :, 1]
          ...

This function does not use MLX. Where exactly is the parallelism? I must be misunderstanding something, because the number of operations looks O(T) to me and fully sequential. I am sure I am missing something simple.
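
To make my confusion concrete, here is the purely sequential reference I have in mind for what the comments describe (my own sketch, not code from this repo):

    import torch

    def sequential_scan(A, X):
        # A, X : (B, D, L, N); returns H with H[t] = A[t] * H[t-1] + X[t],
        # starting from a zero hidden state
        B, D, L, N = A.shape
        H = torch.zeros_like(X)
        h = X.new_zeros(B, D, N)
        for t in range(L):  # L strictly sequential steps
            h = A[:, :, t] * h + X[:, :, t]
            H[:, :, t] = h
        return H

Each iteration of this loop depends on the previous one, so I count T sequential steps and do not see where the 2*log2(T) comes from in the code above.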

 Gordon

delta question

Hello,

Thank you for the amazing work!

I had a question regarding the latest commit. If the following dot product was added to the ssm function, is there a specific reason why it was not added to the ssm_step function as well?

Moreover, the softplus in the ssm function uses the bias of dt_proj, while ssm_step uses the previous implementation, delta = F.softplus(self.dt_proj(delta)).

Lastly, should D here have a _no_weight_decay similar to A_log?


Edit: If the new modifications to delta were intentionally not added to ssm_step, does that mean that during inference I have to use step and cannot use forward?


Edit 2: In the forward function, I can do a bidirectional forward (as in Vision Mamba), such as:

output = self.mamba_block(self.norm(x))
x_flip = x.flip([1])
output_flip = self.mamba_block(self.norm(x_flip))
output += output_flip
return output + x

However, in step, each directional forward will return its own cache, which I am not sure how to handle exactly, as unfortunately I do not fully understand the cache mechanism (h, inputs).

Apologies for the long question.

Thank you!

MLX inference error with BFloat16

from commit hash 6a49341:

What I did was run generate.py with a fine-tuned Mamba model, kuotient/mamba-ko-2.8b, and the error below occurred. How can I deal with this error?

(venv_mamba_py) ******@Mac-Studio-2022-01 scripts % python generate.py --prompt="Mamba is a type of" --hf_model_name="kuotient/mamba-ko-2.8b" --n_tokens=100

Traceback (most recent call last):
  File "/Users/******/test/mamba.py/mlx/scripts/generate.py", line 31, in <module>
    model = MambaLM.from_pretrained(args.hf_model_name)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/******/test/mamba.py/mlx/mamba_lm_mlx.py", line 150, in from_pretrained
    mlx_state_dict = map_mambassm_torch_to_mlx(state_dict)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/******/test/mamba.py/mlx/utils.py", line 53, in map_mambassm_torch_to_mlx
    return map_mambapy_torch_to_mlx(new_state_dict)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/******/test/mamba.py/mlx/utils.py", line 37, in map_mambapy_torch_to_mlx
    new_state_dict[key] = value.numpy()
                          ^^^^^^^^^^^^^
TypeError: Got unsupported ScalarType BFloat16
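
Would upcasting bfloat16 tensors to float32 before the .numpy() call be a reasonable workaround? Something along these lines (my own untested sketch, not a patch taken from mlx/utils.py):

    import torch

    def tensor_to_numpy(value: torch.Tensor):
        # numpy has no bfloat16 dtype, so upcast before converting
        if value.dtype == torch.bfloat16:
            value = value.to(torch.float32)
        return value.numpy()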

My environments are:

MLX memory usage at inference

The 1.4B model takes 10-11GB of RAM at inference. (own test, M2 Pro 16GB)
The 2.8B model takes around 50GB at inference. (https://twitter.com/awnihannun/status/1749515431336112275)

This is not due to loading the model from HF (the memory footprint is the same if the model is initialized with random weights).
Nor is it due to the ssm_step.

However, turning off the convolution at inference reduces the memory footprint (by 3GB for the 1.4B model: from 10GB to around 7GB). It also greatly speeds up inference (but of course, the forward is then not correct).

Files concerned:

  • mamba_mlx.py (step functions)
  • misc.py

The depthwise conv implemented in misc.py seems to be part of the problem.
As noted in the file, the PyTorch version uses groups=channels (a true depthwise conv), while the MLX depthwise conv in misc.py uses groups=1 with some weights set to 0 (the only workaround found so far).
This results in a (d_model, 4, d_model) filter, against (d_model, 4) for the "true" depthwise conv.
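
Roughly, the weight mapping behind that workaround looks like this (a simplified sketch of the idea, not the exact code from the repo):

    import torch

    def depthwise_to_dense_weights(torch_weights: torch.Tensor) -> torch.Tensor:
        # PyTorch depthwise Conv1d weight (groups=channels): (channels, 1, kernel_size)
        channels, _, kernel_size = torch_weights.shape
        # MLX conv1d with groups=1 expects (out_channels, kernel_size, in_channels):
        # keep only the "diagonal" taps so each output channel sees its own input channel
        dense = torch.zeros(channels, kernel_size, channels)
        idx = torch.arange(channels)
        dense[idx, :, idx] = torch_weights[:, 0, :]
        return dense

The zero entries still have to be stored and multiplied, which is where the extra d_model factor in memory and compute comes from.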

Either:
  • wait for MLX to implement groups=channels for conv1d, or
  • find another workaround. One possibility is to create d_model conv objects, each with 1 input and 1 output channel, but this results in a big for loop which is around 45x slower than the current workaround; of course, memory usage is then greatly reduced (by a factor of d_model = 2560).

Question on using a sequence length > the max length I can hold in a batch due to memory usage for training

For running extremely large sequence lengths, I can break the sequence up across batches and use the inference step to save the hidden states between passes. However, step does not work during training, as there seems to be a problem with variables being overwritten before loss.backward() can be run.

How might you modify the forward pass to allow running the same parallel scan as used during training, but connect previous hidden states from other batches?
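
To make the question concrete, this is the kind of training loop I am imagining (the h0 argument and the returned hidden state are hypothetical; forward does not currently accept or return them):

    import torch
    import torch.nn.functional as F

    def train_on_long_sequence(model, x, y, optim, chunk_len):
        # x, y : (B, L_total) token ids; split the long sequence into chunks and
        # carry the scan's final hidden state from one chunk to the next
        h = None
        for xc, yc in zip(x.split(chunk_len, dim=1), y.split(chunk_len, dim=1)):
            logits, h = model(xc, h0=h)  # hypothetical signature
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), yc.reshape(-1))
            optim.zero_grad()
            loss.backward()
            optim.step()
            h = h.detach()  # truncated BPTT: stop gradients at chunk boundaries
        return h

Would carrying (and detaching) the hidden state like this be compatible with the parallel scan, or does pscan fundamentally assume the scan starts from zero?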

Why use element-wise multiplication rather than matrix multiplication in the function `selective_scan_seq`

Hello. In the function selective_scan_seq, there are two points that confuse me:

  • BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
  • h = deltaA[:, t] * h + BX[:, t]
    These two lines of code seem to use element-wise multiplication.

However, in the paper, the equation is
$$h_t = \bar{A} h_{t-1} + \bar{B} x_{t}$$
Both terms on the right-hand side of the equation are computed with matrix multiplication.

I am curious whether the two lines of code use some trick to convert the matrix multiplication into an element-wise one.
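
My current guess, which I would like to confirm: since the A used here is diagonal along the state dimension, multiplying by it collapses to an element-wise product, e.g.:

    import torch

    # for a diagonal matrix, the matrix-vector product equals an element-wise product
    N = 16
    a = torch.randn(N)  # diagonal entries of a hypothetical A_bar for one channel
    h = torch.randn(N)  # hidden state for that channel
    assert torch.allclose(torch.diag(a) @ h, a * h)

Is that the right way to read deltaA[:, t] * h?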

Error in the Mamba Block forward function?

I am going through the code line by line and adding additional comments with shape information. Inside the MambaBlock's forward function, duplicated below:

    def forward(self, x):
        # input x : (B, L, D)
        # return : (B, L, D)

        _, L, _ = x.shape

        # in_proj: D -->  2*ED (two or three branches)
        xz = self.in_proj(x) # (B, L, 2*ED)
        x, z = xz.chunk(2, dim=-1) # (B, L, ED), (B, L, ED)

        # x branch
        # rearrange(x, "... L ED -> ... ED L")
        x = x.transpose(1, 2) # (B, ED, L)
        # What is the point of convolution?
        x = self.conv1d(x)[:, :, :L] # depthwise convolution over time, with a short filter
        # rearrange(x, "... ED L -> ... L ED")
        x = x.transpose(1, 2) # (B, L, ED)

        #  x --> conv1d --> silu --> ssm  --> y ---> output (y*z) --> (B, L, ED)
        #                             ^          /
        #  z -------------> silu ---- |---------/

        x = F.silu(x)
        y = self.ssm(x, z)   # (B, L, ED), (B, L, ED) --> (B, L, ED)
        print(f"Return from self.ssm, {y.shape=}")

        # GE: why the early exit if using CUDA?
        if self.config.use_cuda:     ######<<<<<<<<
            output = self.out_proj(y) # (B, L, D)
            return output

        # z branch
        z = F.silu(z)   #  (B, L, ED)

        # Why multiply y * z?
        output = y * z
        #              ED -> D
        output = self.out_proj(output) # (B, L, D)
        return output  # (B, L, D)

there is the conditional:
if self.config.use_cuda:
Depending on whether CUDA is used or not, the output of the forward() method is different. If CUDA is not used, there is an additional F.silu applied to z, followed by output = y * z.

If this is not an error, could you please explain the reason for this apparent discrepancy? Thanks.

Gordon

MuP

Hello, and thanks for sharing your code. I stumbled on your repo while looking for how to implement muP for Mamba. It seems like you implemented muP without scaling any attn-like matrices. Does that mean that muP works with Mamba out of the box, as long as the right initializations (from the mup package, for example) are implemented?

Thanks for your help.

Pscan documentation

Hi, I really like this project.
I was hoping you could also finish the pscan documentation in the notebook. It has already given me a lot of clarity, and more would be very helpful.

A beginner can't start the model

I'm sorry, I'm a beginner, so I can't get the model running smoothly.
When I start 'example_llm.ipynb', I get an error like this:
[screenshot of the error]
and when I copy the same code into a .py file, I get an error like:
[screenshot of the error]

I ran into the same problem with 'mamba-minimal', and I don't know how to solve it. Thanks in advance for any help!

Can I translate your PScan to Jax?

Thanks a lot for your work and minimal implementation!

For work, I need to implement some models to benchmark, and I really want to include mamba-related models.

To do so, I created Jimmy (for Jax Image Model :), really not prod-ready at all yet: https://github.com/clementpoiret/jimmy
But porting the CUDA code into something that can be compiled by XLA is just beyond what I can do right now.

With credits of course, may I port your pscan to Jimmy?
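
For what it's worth, my current plan is to lean on jax.lax.associative_scan rather than porting the kernel directly, roughly like this (untested sketch, tensor layout chosen arbitrarily):

    import jax

    def pscan_jax(A, X):
        # A, X : (B, L, D, N); computes H[t] = A[t] * H[t-1] + X[t] with a zero initial state
        def combine(left, right):
            a_l, x_l = left
            a_r, x_r = right
            # composition of h -> a_l*h + x_l followed by h -> a_r*h + x_r
            return a_l * a_r, a_r * x_l + x_r

        _, H = jax.lax.associative_scan(combine, (A, X), axis=1)
        return H

Does that look like a faithful equivalent of your PScan (minus the in-place tricks and the power-of-two restriction)?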

Thanks!

Onnx export for the inference

Hi Alex,

I have tried to generate the ONNX file for inference. I called the ONNX export inside the generate function, here: https://github.com/alxndrTL/mamba.py/blob/main/mamba_lm.py#L134

as follows:

torch.onnx.export(model, inputids, "mamba.onnx", opset_version=12)

It throws the following error:

TypeError: MambaLM.forward() missing 1 required positional argument: 'tokens'.

Any idea how I can generate the ONNX file? Is there a better way of exporting the model to ONNX for inference?

Up sweep in parallel scan

Thank you for your great work.

In your parallel scan, when Xa.size(2) == 2 or Xa.size(2) == 1, why do you skip the up-sweep operation?
(line 64 in the file)

Is it related to the statement in your document that this isn't a static tree, but a representation of the evolution of our tensor in memory?

About the speed test

Thank you for sharing your fantastic work.
We have noticed in the figure that as the dimension d_state rises, Mamba's time cost does not rise.
However, we found the following in the code (selective_scan_fwd_kernel.cuh#163):

for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
    ...
   if constexpr (kIsVariableB) {
                load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
                    smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
   }
}

which shows a for loop over state_idx that reads from HBM into shared memory.

Then I tested the speed again and found that as d_state rises, Mamba's time cost rises roughly linearly, which is consistent with the code.

    device = torch.device("cuda")
    dtype = torch.float32
    B, L, G, D, N, R = 3, 4096, 4, 192, 16, 192 // 16
    xi = torch.randn((B, G * D, L), device=device, dtype=dtype)
    Ai = torch.randn((G * D, N), device=device, dtype=dtype)
    Di = torch.randn((G * D), device=device, dtype=dtype)
    dti = torch.randn((B, G * D, L), device=device, dtype=dtype)
    Bi = torch.randn((B, G, N, L), device=device, dtype=dtype)
    Ci = torch.randn((B, G, N, L), device=device, dtype=dtype)
    tpb = torch.randn((G * D), device=device, dtype=dtype)

    Ai2 = torch.randn((G * D, 4*N), device=device, dtype=dtype)
    Bi2 = torch.randn((B, G, 4*N, L), device=device, dtype=dtype)
    Ci2 = torch.randn((B, G, 4*N, L), device=device, dtype=dtype)

    import time
    tim0 = time.time()
    for _ in range(1000):
        y = selective_scan_fn(xi, dti, Ai, Bi, Ci, Di, tpb, True)
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    tim1 = time.time()
    for _ in range(1000):
        y = selective_scan_fn(xi, dti, Ai2, Bi2, Ci2, Di, tpb, True)
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    tim2 = time.time()
    print(tim1-tim0, tim2-tim1, torch.cuda.max_memory_allocated()) # 0.7172577381134033 2.400775194168091 185063424
    time.sleep(100000)

So what did I miss?

Discretization step seems to be different from the paper

Hi,

Firstly, thanks for making this repo. I found it very useful in understanding the scan algorithm. However, the discretization in this repo seems to be different from Eq. 4 of the paper. Do you have any comments on this? Also, I wonder why the original paper needs the discretization step in the first place, since it is possible to condition the discrete versions of A and B on the input directly. I imagine that it must have something to do with the initialization, but I am not sure.

deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)

[screenshot of Eq. 4 from the paper]
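
For reference, here is my reading of Eq. 4 (the zero-order hold) next to what the line above seems to implement (a first-order approximation); please correct me if I transcribed it wrong:

$$\bar{A} = \exp(\Delta A), \qquad \bar{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right) \Delta B$$

whereas deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) corresponds to the simplification

$$\bar{B} \approx \Delta B$$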

Please rectify paths in example_e2e_training.ipynb

Please rectify the paths

from example_src.tinyhome import TinyHomeEngineV1, print_grid, print_act
from example_src.buffer import ReplayBuffer

to

from examples.tinyhome import TinyHomeEngineV1, print_grid, print_act
from examples.buffer import ReplayBuffer

Partial batches in Mamba_lm

I found an unnecessary check. In the get_data() function of mamba_lm, you clearly ensure that there are no partial batches.

In the training loop (below the get_batch() function), you check whether the batch is partial and skip it if so. This check is unnecessary.

Here is the relevant code. First get_batch():

    def get_batch(
        data: Float[T, " B Examples"], seq_len: int, idx: int
    ) -> tuple[Float[T, "B SeqLen"], Float[T, "B SeqLen"]]:
        """Retrieve a single batch"""

        src = data[:, idx : idx + seq_len]  # noqa: E203
        target = data[:, idx + 1 : idx + seq_len + 1]  # noqa: E203
        return src, target

where the batch size is always the same. There are no partial batches. Below is the extraneous check:

            # If the batch is not complete - skip
            ###  The batch is always complete
            if logits.view(-1, logits.size(-1)).shape[0] != output.view(-1).shape[0]:   # <<<< UNNECESSARY
                print("skip")
            else:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), output)
                avg_loss += loss.item()

                optim.zero_grad()
                loss.backward()

Segmentation fault with MLX

Segmentation fault: 11 while running Mamba inference with MLX

https://github.com/alxndrTL/mamba.py/tree/main/mlx

python3 generate.py --prompt="Mamba is a type of" --hf_model_name="state-spaces/mamba-130m" --n_tokens=100

on an Apple M1 Pro

I found that the line that causes the error is

mlx_weights = torch.zeros(channels, kernel_size, channels)

in the function torch_to_mlx_depthwise_weights,

but I don't know how to fix it.

Cuda Version

Hi, great work!

How do I enable CUDA? I found this:

if self.config.use_cuda:
    try:
        from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # I did not find mamba_ssm in this repo
        self.selective_scan_cuda = selective_scan_fn
    except ImportError:
        print("Failed to import mamba_ssm. Falling back to mamba.py.")
        self.config.use_cuda = False

Looking forward to your response.

Thanks !

huge huge memory usage!!

I find that the pscan method used in this Mamba implementation uses a huge amount of memory! Any idea how to reduce memory consumption, or how to replace the pscan method with another implementation?
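
For reference, my rough understanding of why it blows up: pscan works on (B, D, L, N) buffers (see pscan.py), and several of them, plus autograd state, are alive at once. With made-up sizes:

    # rough memory estimate for one (B, D, L, N) float32 buffer used by pscan
    # (the sizes below are hypothetical; plug in your own config)
    B, D, L, N = 1, 4096, 2048, 16
    print(B * D * L * N * 4 / 2**20, "MiB")  # 512.0 MiB per buffer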

great thanks!

Possible SSM-Transformers implementation?

Hey! Awesome work on this project! I know it's not technically vanilla Mamba but I've been trying to convert the new SSM-Transformers Jamba into MLX for more efficient training and usability but am having a difficult time. My specialty is in the training/datasets world and not the strongest in the core math behind the model architectures beyond the basic implementations.

Would somebody know of an easier way to get Jamba converted into MLX? I truly think Jamba has A LOT to offer and could do some awesome stuff in the MLX format and for local model training on Macs.

I've provided the modeling script released by AI21 for quick reference. Is this feasible or just way too complicated at the moment?

modeling_jamba.txt

Default implementation of Jamba

Here is a section of code in JambaLM

class Jamba(nn.Module):
    def __init__(self, config: JambaLMConfig):
        super().__init__()

        self.config = config

        # init each model layer, decide if it's mamba/attention and has experts or not
        decoder_layers = []
        for i in range(config.n_layers):
            is_attn = (
                True
                if (i - self.config.attn_layer_offset) % self.config.attn_layer_period
                == 0
                else False
            )
            is_expert = (
                True
                if (i - self.config.expert_layer_offset)
                % self.config.expert_layer_period
                == 0
                else False
            )

You'll notice that the structure of is_attn and is_expert is identical. Furthermore, with the default configuration provided, is_attn and is_expert always take the same value (both False, or both True on the same layers). As a result, all the layers in this default Jamba architecture end up being the same. Of course I can change that, but this is surely not intended, given that this code is didactic. Thanks.
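
To illustrate the point with made-up offsets and periods (not the repo's defaults):

    # illustration only: how the offset/period arithmetic assigns layer types
    n_layers = 8
    attn_layer_offset, attn_layer_period = 1, 4
    expert_layer_offset, expert_layer_period = 1, 4

    for i in range(n_layers):
        is_attn = (i - attn_layer_offset) % attn_layer_period == 0
        is_expert = (i - expert_layer_offset) % expert_layer_period == 0
        print(i, is_attn, is_expert)
    # with identical offsets and periods the two flags coincide on every layer,
    # so no layer ends up attention-only or expert-only

With distinct offsets and periods the two flags would differ across layers, which is presumably the intended interleaving.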

Can we get an explicit license?

Would it be possible to put an explicit OSS license on the codebase, just to remove the approval burden when experimenting based on this code? I want to prototype a bit on my laptop to get the plumbing right before moving to cloud GPUs to actually train, and this looks like the best CPU-friendly implementation I have found for that purpose.
