lightning-attention's People

Contributors

doraemonzzz, weigao266, xuyangshen

lightning-attention's Issues

Cannot run the triton kernels

Thanks for this repo, I'm pretty excited to test this out.

I dropped the attention from lightning-attention into one of my projects as a drop-in replacement and got the following:

RuntimeError: PassManager::run failed
Traceback (most recent call last):
  File "/opt/ml/code/open_lm/main.py", line 873, in <module>
    main(sys.argv[1:])
  File "/opt/ml/code/open_lm/main.py", line 774, in main
    success, global_step = train_one_epoch(
  File "/opt/ml/code/open_lm/train.py", line 267, in train_one_epoch
    backward(local_loss, scaler)
  File "/opt/ml/code/open_lm/train.py", line 92, in backward
    total_loss.backward()
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/lightning-attention/lightning_attn/ops/triton/lightning_attn2.py", line 462, in backward
    _bwd_intra_kernel[grid](
  File "<string>", line 63, in _bwd_intra_kernel
  File "/opt/conda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 476, in compile
    next_module = compile_kernel(module)
  File "/opt/conda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 383, in <lambda>
    lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, arch))
  File "/opt/conda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 91, in optimize_ttgir
    pm.run(mod)

So I tried simply running pytest tests/ops/test_lightning2.py
and got only failures (it is odd that there is an assert False statement in there...).
The more worrisome result is that the errors are quite large:

tests/ops/test_lightning2.py FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF        [100%]

tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-256-128-64] tensor(0.1543, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.1650, device='cuda:0', dtype=torch.bfloat16)
tensor(0.1641, device='cuda:0', dtype=torch.bfloat16)
tensor(0.1641, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-512-128-64] tensor(0.2393, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.2539, device='cuda:0', dtype=torch.bfloat16)
tensor(0.2520, device='cuda:0', dtype=torch.bfloat16)
tensor(0.2539, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-1024-128-64] tensor(0.3555, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.3770, device='cuda:0', dtype=torch.bfloat16)
tensor(0.3750, device='cuda:0', dtype=torch.bfloat16)
tensor(0.3750, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-2048-128-64] tensor(0.5117, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.5430, device='cuda:0', dtype=torch.bfloat16)
tensor(0.5430, device='cuda:0', dtype=torch.bfloat16)
tensor(0.5430, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-4096-128-64] tensor(0.7344, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.7773, device='cuda:0', dtype=torch.bfloat16)
tensor(0.7734, device='cuda:0', dtype=torch.bfloat16)
tensor(0.7734, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-8192-128-64] tensor(1.0391, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(1.1016, device='cuda:0', dtype=torch.bfloat16)
tensor(1.1016, device='cuda:0', dtype=torch.bfloat16)
tensor(1.1016, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-2048-32-64] tensor(0.2578, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.2715, device='cuda:0', dtype=torch.bfloat16)
tensor(0.2715, device='cuda:0', dtype=torch.bfloat16)
tensor(0.2715, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-2048-64-64] tensor(0.3633, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.3828, device='cuda:0', dtype=torch.bfloat16)
tensor(0.3848, device='cuda:0', dtype=torch.bfloat16)
tensor(0.3828, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-12-2048-128-64] tensor(0.5234, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.5547, device='cuda:0', dtype=torch.bfloat16)
tensor(0.5547, device='cuda:0', dtype=torch.bfloat16)
tensor(0.5547, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-16-2048-128-64] tensor(0.6719, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.7148, device='cuda:0', dtype=torch.bfloat16)
tensor(0.7148, device='cuda:0', dtype=torch.bfloat16)
tensor(0.7109, device='cuda:0', dtype=torch.bfloat16)
FAILED

Does using lightning-attention require retraining?

Hello, I have replaced the normal self-attention calculation in my own model with lightning attention, without any additional operations, but I found that the model performs poorly at inference and evaluation.

Therefore, I would like to ask: does simply replacing the normal self-attention calculation with lightning attention have any effect on model accuracy? Do I need to retrain my model? Thank you very much!

I also found another issue (https://github.com/OpenNLPLab/lightning-attention/issues/10#issuecomment-1986779377) in which you said lightning attention has no parameters, so perhaps, like flash attention, it should not affect model accuracy and does not need retraining?

Here is the code I used to calculate the attention forward process originally:

def forward(self, x: torch.Tensor) -> torch.Tensor:
    B, H, W, _ = x.shape

    # qkv with shape (3, B, nHead, H * W, C)
    qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
    # q, k, v with shape (B * nHead, H * W, C)
    q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

    attn = (q * self.scale) @ k.transpose(-2, -1)

    if self.use_rel_pos:
        attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

    attn = attn.softmax(dim=-1)
    x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)

    x = self.proj(x)

    return x

Here is my modified code for computing the attention forward process using LIGHTNING ATTENTION:

def forward(self, x: torch.Tensor) -> torch.Tensor:
    B, H, W, _ = x.shape

    # qkv with shape (3, B, num_heads, H * W, C)
    qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)  # unbind the 3 tensors

    # build slope tensor for Lightning Attention
    slope_tensor = _build_slope_tensor(self.num_heads).to(x.device).to(torch.float32)

    # compute attention using Lightning Attention
    attn = lightning_attn_func(q, k, v, slope_tensor)

    # reshape attention output
    attn = attn.view(B, H, W, -1)

    # final projection
    x = self.proj(attn)

    return x
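
A quick way to compare the two attention paths on random inputs is something along these lines (a rough sketch; it assumes lightning_attn_func takes (b, h, n, d) tensors plus the slope tensor, and the softmax path here has no mask or decay, so it only shows whether the swap is numerically equivalent, not which variant is better):

import torch
from lightning_attn.ops import lightning_attn_func
from lightning_attn.utils import _build_slope_tensor

b, h, n, d = 2, 8, 256, 64
q = torch.randn(b, h, n, d, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
s = _build_slope_tensor(h).to(q.device).to(torch.float32)

# plain softmax attention on the same tensors (no mask, no decay)
softmax_o = torch.softmax((q * d ** -0.5) @ k.transpose(-2, -1), dim=-1) @ v
lightning_o = lightning_attn_func(q, k, v, s)

# a large difference here means the replacement is not numerically equivalent
print((softmax_o - lightning_o).abs().max())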

No module named 'lightning_attn'

Great work!

After installing the module with 'pip install lightning_attn', it still raises the error 'No module named 'lightning_attn''.

But it does indeed exist in my env.

Weird...

I need some help; thank you in advance!
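
For what it's worth, a quick way to check whether the interpreter running the script is the same one pip installed into (a diagnostic sketch only, not a confirmed fix):

import sys
import importlib.util

# the interpreter actually executing the script
print(sys.executable)

# None here means this interpreter cannot see the package, which usually
# means pip installed it into a different environment
print(importlib.util.find_spec("lightning_attn"))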

TypeError("unhashable type: 'tensor'")

Traceback (most recent call last):
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1124, in ast_to_ttir
generator.visit(fn.parse())
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1017, in visit
ret = super().visit(node)
File "/share/database/code/esmfold/lib/python3.7/ast.py", line 271, in visit
return visitor(node)
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 293, in visit_Module
ast.NodeVisitor.generic_visit(self, node)
File "/share/database/code/esmfold/lib/python3.7/ast.py", line 279, in generic_visit
self.visit(item)
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1017, in visit
ret = super().visit(node)
File "/share/database/code/esmfold/lib/python3.7/ast.py", line 271, in visit
return visitor(node)
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 362, in visit_FunctionDef
self.visit_compound_statement(node.body)
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 288, in visit_compound_statement
ret_type = self.visit(stmt)
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1017, in visit
ret = super().visit(node)
File "/share/database/code/esmfold/lib/python3.7/ast.py", line 271, in visit
return visitor(node)
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 414, in visit_Assign
values = self.visit(node.value)
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1017, in visit
ret = super().visit(node)
File "/share/database/code/esmfold/lib/python3.7/ast.py", line 271, in visit
return visitor(node)
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 934, in visit_Call
args = [self.visit(arg) for arg in node.args]
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 934, in
args = [self.visit(arg) for arg in node.args]
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1017, in visit
ret = super().visit(node)
File "/share/database/code/esmfold/lib/python3.7/ast.py", line 271, in visit
return visitor(node)
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 462, in visit_BinOp
lhs = self.visit(node.left)
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1017, in visit
ret = super().visit(node)
File "/share/database/code/esmfold/lib/python3.7/ast.py", line 271, in visit
return visitor(node)
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 650, in visit_UnaryOp
op = self.visit(node.operand)
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1017, in visit
ret = super().visit(node)
File "/share/database/code/esmfold/lib/python3.7/ast.py", line 271, in visit
return visitor(node)
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 929, in visit_Call
static_implementation = self.statically_implemented_functions.get(fn)
TypeError: unhashable type: 'tensor'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "L_attention.py", line 22, in
o = lightning_attn_func(q, k, v, s)
File "/share/database/code/esmfold/lib/python3.7/site-packages/lightning_attn/ops/lightning_attn_interface.py", line 34, in lightning_attn_func
o = lightning_attn2(q, k, v, s)
File "/share/database/code/esmfold/lib/python3.7/site-packages/lightning_attn/ops/triton/lightning_attn2.py", line 429, in forward
BLOCK_MODEL=BLOCK_MODEL,
File "", line 63, in _fwd_kernel
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/compiler.py", line 476, in compile
next_module = compile_kernel(module)
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/compiler.py", line 381, in
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
File "/share/database/code/esmfold/lib/python3.7/site-packages/triton/compiler/code_generator.py", line 1133, in ast_to_ttir
raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 39:22: K_trans_block_ptr = K + qk_offset + tl.arange(0, d)[:, None]
V_block_ptr = V + v_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
O_block_ptr = Out + o_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
S_block_ptr = S + off_h

##### init diag decay(Lambda); q, k decay; kv
s = tl.load(S_block_ptr)
# q, k decay
off_block = tl.arange(
    0, BLOCK
)  # Not bug, this is a bit different from algorithm 1, but is mathematically equivalent
q_decay = tl.exp(-s.to(tl.float32) * off_block[:, None])
                  ^

TypeError("unhashable type: 'tensor'")

assert d in supports_dim and e in supports_dim ?

Thank you for the nice implementation! It seems that dim=192 is not in supports_dim. Why is that the case? Could you add support for dim=192?

I tried this script

import torch
from lightning_attn.ops import lightning_attn_func
from lightning_attn.utils import _build_slope_tensor

dtype = torch.bfloat16
device = torch.device("cuda")
b, h, n, d, e = 2, 12, 2048, 192, 192

q = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
k = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
v = torch.randn((b, h, n, e), dtype=dtype, device=device).requires_grad_()
s = _build_slope_tensor(h).to(q.device).to(torch.float32)

o = lightning_attn_func(q, k, v, s)

print(o.shape)

and got this error

    o = lightning_attn_func(q, k, v, s)
  File "/opt/tiger/mariana/lightning-attention-main/lightning_attn/ops/lightning_attn_interface.py", line 10, in lightning_attn_func
    assert d in supports_dim and e in supports_dim
AssertionError
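
One possible workaround (a sketch I have not verified; it assumes your installed version's supports_dim contains a value >= 192, e.g. 256) is to zero-pad the head dimensions up to a supported size and slice the output back. Zero-padding q and k along d leaves every q·k dot product unchanged, and zero-padding v along e only adds output channels that stay zero:

import torch.nn.functional as F

# reuses q, k, v, s, d, e from the script above; 256 being supported is an assumption
pad_d, pad_e = 256 - d, 256 - e
q_p = F.pad(q, (0, pad_d))  # pad the last dim with zeros
k_p = F.pad(k, (0, pad_d))
v_p = F.pad(v, (0, pad_e))

o = lightning_attn_func(q_p, k_p, v_p, s)[..., :e]  # drop the padded output channels
print(o.shape)  # (b, h, n, 192)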

When running lightning_attn_func two or more times, an error occurs

"In training, when I run lightning_attn_func two or more times, I encounter an exception with the content
“triton.runtime.autotuner.OutOfResources: out of resource: shared memory, Required: 114688, Hardware limit: 101376. Reducing block sizes or num_stages may help.”
The partial code snippet of my execution is as follows
x = self.norm(self.attention(x,x,x,_build_slope_tensor(self.num_heads).to(x.device).to(torch.float32))) x = self.norm(self.ff(x))+x x = lightning_attn_func(x,x,x,_build_slope_tensor(self.num_heads).to(x.device).to(torch.float32))

The code gets stuck when running example_lightning_attn.py

Thanks for your excellent work. I have created the running environment as you described in readme.md. However, when I run ./examples/ops/example_lightning_attn.py, the code gets stuck in the lightning_attn_func function. Further debugging revealed that the problem lies in the _fwd_kernel[grid] call. The PyTorch and Triton versions are 2.0.1 and 2.0.0, respectively, and the GPU is an NVIDIA 3090 Ti with CUDA 11.7. Could you please help me solve this problem? Looking forward to hearing from you. Thank you very much.

The methods for saving the Lightning-Attention model

I am using torch.save to save a model that contains the lightning_attn_func function. It seems that the model does not save any parameters related to lightning_attn_func. When I reload the model, I find that the results are inconsistent with those during training. Is this because the save method is incorrect, or is there some other reason the save does not work?
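
For reference, since lightning_attn_func itself has no learnable parameters (as another issue on this page notes), a minimal checkpointing sketch would only persist the surrounding module's weights via the usual state_dict round trip (model here is a placeholder name for your own nn.Module):

import torch

# lightning_attn_func contributes nothing to the state dict; only the
# surrounding module's parameters need to be saved
torch.save(model.state_dict(), "checkpoint.pt")

# ... later ...
model.load_state_dict(torch.load("checkpoint.pt"))
model.eval()  # make sure train/eval mode matches when comparing results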

Tests fail

All the tests currently fail:

FAILED ops/test_lightning2.py::test_lightning2[dtype1-2-8-2048-128-64] - asse...
FAILED ops/test_lightning2.py::test_lightning2[dtype1-3-8-2048-128-64] - asse...
FAILED ops/test_lightning2.py::test_lightning2[dtype1-6-8-913-128-64] - asser...
FAILED ops/test_lightning2.py::test_lightning2[dtype1-6-8-513-128-64] - asser...
FAILED ops/test_lightning2.py::test_lightning2[dtype1-6-8-1213-128-64] - asse...
FAILED ops/test_lightning2.py::test_lightning2[dtype1-6-8-2048-16-64] - asser...

To reproduce, I checked out the repo, set up a conda env, ran pip install -e ., and then:

cd tests
. ./script.sh

I also wrote my own unit test and it fails:

First differing element at index (0, 1, 1, 0): linear=6.259500503540039, lightning=6.514106273651123
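
For context, the kind of naive baseline such a test can compare against might look like the following (a sketch only; it assumes lightning attention computes causal linear attention with a per-head exponential decay, which matches the decay terms visible in the kernel excerpt earlier on this page, but it is not the repo's own reference implementation):

import torch

def naive_lightning_attn(q, k, v, s):
    # O(n^2) reference: causal linear attention with decay exp(-slope * (t - i))
    b, h, n, d = q.shape
    idx = torch.arange(n, device=q.device)
    diff = (idx[:, None] - idx[None, :]).float()        # t - i
    causal = diff >= 0
    decay = torch.exp(-s.reshape(h, 1, 1) * diff.clamp(min=0)) * causal
    qk = torch.einsum("bhnd,bhmd->bhnm", q.float(), k.float())
    return torch.einsum("bhnm,bhme->bhne", qk * decay, v.float()).to(q.dtype)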

Can lightning_attn support the V100?

import torch

from lightning_attn.ops import lightning_attn_func
from lightning_attn.utils import _build_slope_tensor

dtype = torch.bfloat16
device = torch.device("cuda")
b, h, n, d, e = 2, 12, 2048, 192, 192

q = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
k = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
v = torch.randn((b, h, n, e), dtype=dtype, device=device).requires_grad_()
s = _build_slope_tensor(h).to(q.device).to(torch.float32)

o = lightning_attn_func(q, k, v, s)

print(o.shape)

loss = o.sum()
loss.backward()

The above code takes a very long time to run, more than eight minutes. Is this duration normal?
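
A rough way to separate compilation time from steady-state time (a sketch; the assumption is that much of that time may go into the first-call Triton JIT compilation and autotuning rather than the kernel itself) is to time a second call, reusing q, k, v, s from the script above:

import time
import torch

def timed(fn):
    torch.cuda.synchronize()
    t0 = time.time()
    out = fn()
    torch.cuda.synchronize()
    return out, time.time() - t0

# the first call pays Triton compilation / autotuning
_, t_first = timed(lambda: lightning_attn_func(q, k, v, s))
# the second call should reflect the actual kernel time
_, t_warm = timed(lambda: lightning_attn_func(q, k, v, s))
print(f"first call: {t_first:.2f}s, warm call: {t_warm:.2f}s")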

