OpenNLPLab / lightning-attention

Lightning Attention-2: A Free Lunch for Handling Unlimited Sequence Lengths in Large Language Models
MIT License
184 stars 15 forks

Cannot run the triton kernels #1

Closed — jmercat closed this 6 months ago

jmercat commented 9 months ago

Thanks for this repo, I'm pretty excited to test this out.

I drop-in replaced attention from lightning-attention in one of my projects and got the following:

RuntimeError: PassManager::run failed
Traceback (most recent call last):
  File "/opt/ml/code/open_lm/main.py", line 873, in <module>
    main(sys.argv[1:])
  File "/opt/ml/code/open_lm/main.py", line 774, in main
    success, global_step = train_one_epoch(
  File "/opt/ml/code/open_lm/train.py", line 267, in train_one_epoch
    backward(local_loss, scaler)
  File "/opt/ml/code/open_lm/train.py", line 92, in backward
    total_loss.backward()
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/lightning-attention/lightning_attn/ops/triton/lightning_attn2.py", line 462, in backward
    _bwd_intra_kernel[grid](
  File "<string>", line 63, in _bwd_intra_kernel
  File "/opt/conda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 476, in compile
    next_module = compile_kernel(module)
  File "/opt/conda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 383, in <lambda>
    lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, arch))
  File "/opt/conda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 91, in optimize_ttgir
    pm.run(mod)

So I tried simply running pytest tests/ops/test_lightning2.py and got only failures (it is odd that there is an assert False statement in the test file...). More worryingly, the reported errors are quite large:

tests/ops/test_lightning2.py FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF        [100%]

tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-256-128-64] tensor(0.1543, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.1650, device='cuda:0', dtype=torch.bfloat16)
tensor(0.1641, device='cuda:0', dtype=torch.bfloat16)
tensor(0.1641, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-512-128-64] tensor(0.2393, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.2539, device='cuda:0', dtype=torch.bfloat16)
tensor(0.2520, device='cuda:0', dtype=torch.bfloat16)
tensor(0.2539, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-1024-128-64] tensor(0.3555, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.3770, device='cuda:0', dtype=torch.bfloat16)
tensor(0.3750, device='cuda:0', dtype=torch.bfloat16)
tensor(0.3750, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-2048-128-64] tensor(0.5117, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.5430, device='cuda:0', dtype=torch.bfloat16)
tensor(0.5430, device='cuda:0', dtype=torch.bfloat16)
tensor(0.5430, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-4096-128-64] tensor(0.7344, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.7773, device='cuda:0', dtype=torch.bfloat16)
tensor(0.7734, device='cuda:0', dtype=torch.bfloat16)
tensor(0.7734, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-8192-128-64] tensor(1.0391, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(1.1016, device='cuda:0', dtype=torch.bfloat16)
tensor(1.1016, device='cuda:0', dtype=torch.bfloat16)
tensor(1.1016, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-2048-32-64] tensor(0.2578, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.2715, device='cuda:0', dtype=torch.bfloat16)
tensor(0.2715, device='cuda:0', dtype=torch.bfloat16)
tensor(0.2715, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-2048-64-64] tensor(0.3633, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.3828, device='cuda:0', dtype=torch.bfloat16)
tensor(0.3848, device='cuda:0', dtype=torch.bfloat16)
tensor(0.3828, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-12-2048-128-64] tensor(0.5234, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.5547, device='cuda:0', dtype=torch.bfloat16)
tensor(0.5547, device='cuda:0', dtype=torch.bfloat16)
tensor(0.5547, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-16-2048-128-64] tensor(0.6719, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)
tensor(0.7148, device='cuda:0', dtype=torch.bfloat16)
tensor(0.7148, device='cuda:0', dtype=torch.bfloat16)
tensor(0.7109, device='cuda:0', dtype=torch.bfloat16)
FAILED
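As a side note on reading these numbers (not part of the repo's tests): the printed values are absolute difference norms, which grow with sequence length simply because the tensors themselves grow. A relative error is usually more informative for bf16 comparisons. A minimal sketch, with a hypothetical rel_error helper:

```python
import math

def rel_error(out, ref):
    """Relative L2 error ||out - ref|| / ||ref|| between two flat sequences."""
    diff = math.sqrt(sum((o - r) ** 2 for o, r in zip(out, ref)))
    norm = math.sqrt(sum(r * r for r in ref))
    return diff / norm

# bfloat16 keeps only ~8 significand bits, so element-wise relative error
# on the order of 2**-8 (~0.4%) is expected even from a correct kernel.
ref = [1.0, -2.0, 3.0, -4.0]
out = [v * (1 + 2 ** -8) for v in ref]  # simulate bf16-level rounding
assert rel_error(out, ref) < 0.01
```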
Doraemonzzz commented 9 months ago

Hi, thank you for providing the information.

The first issue is most likely related to the Triton version. The versions that were tested locally and work fine are as follows.

$ pip list | grep triton
triton                   2.0.0
triton-nightly           2.1.0.dev20230728172942

You can install the matching packages with the following commands:

pip install triton==2.0.0
pip install triton-nightly==2.1.0.dev20230728172942 --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/
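To confirm the environment matches, it can help to compare only the numeric release segment of the version string, which handles nightly tags such as 2.1.0.dev20230728172942. A small sketch (parse_release is a hypothetical helper, not part of this repo):

```python
def parse_release(version):
    """Keep only the leading numeric release segment: '2.1.0.dev...' -> (2, 1, 0)."""
    parts = []
    for piece in version.split("."):
        if piece.isdigit():
            parts.append(int(piece))
        else:
            break  # stop at the first non-numeric segment, e.g. 'dev...'
    return tuple(parts)

# The nightly build above satisfies a >= 2.1.0 check despite the .dev suffix.
assert parse_release("2.1.0.dev20230728172942") >= (2, 1, 0)
assert parse_release("2.0.0") < (2, 1, 0)
```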

As for the second question, you can ignore it for now. Because of inherent precision limitations in Triton kernels, some numerical error is unavoidable. However, we have trained models with this kernel and compared them to the baseline (the torch version), and there is almost no difference in loss, so you can use it with confidence.
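For readers who want to build such a torch baseline themselves: the computation the kernel implements is, up to blocking and per-head decay details, decayed linear attention, which has equivalent recurrent and quadratic (masked) forms. A toy pure-Python check of that equivalence, with illustrative shapes and an illustrative decay value:

```python
import random

# Toy decayed linear attention, single head; n = seq len, d = key dim,
# e = value dim, lam = decay. All values here are illustrative.
n, d, e, lam = 4, 2, 3, 0.9
random.seed(0)
q = [[random.random() for _ in range(d)] for _ in range(n)]
k = [[random.random() for _ in range(d)] for _ in range(n)]
v = [[random.random() for _ in range(e)] for _ in range(n)]

# Recurrent form: kv_t = lam * kv_{t-1} + k_t v_t^T, then o_t = q_t kv_t.
kv = [[0.0] * e for _ in range(d)]
o_rec = []
for t in range(n):
    for i in range(d):
        for j in range(e):
            kv[i][j] = lam * kv[i][j] + k[t][i] * v[t][j]
    o_rec.append([sum(q[t][i] * kv[i][j] for i in range(d)) for j in range(e)])

# Quadratic form: o_t = sum_{s<=t} lam**(t-s) * (q_t . k_s) * v_s.
o_quad = []
for t in range(n):
    row = [0.0] * e
    for s in range(t + 1):
        w = lam ** (t - s) * sum(q[t][i] * k[s][i] for i in range(d))
        for j in range(e):
            row[j] += w * v[s][j]
    o_quad.append(row)

# Both forms compute the same output up to float rounding.
assert all(abs(a - b) < 1e-9
           for ra, rb in zip(o_rec, o_quad) for a, b in zip(ra, rb))
```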

If you encounter any other issues, feel free to ask at any time.