RulinShao / LightSeq

Official repository for LightSeq: Sequence Level Parallelism for Distributed Training of Long Context Transformers

Runtime Error when testing on H100 #8

Closed yihaocs closed 4 months ago

yihaocs commented 4 months ago

Hi,

Thanks for the great work and the code. I tried to run the test on H100, but it gave me the following error:

[rank7]: Traceback (most recent call last):
[rank7]:   File "/export/home/research/mm/LightSeq/lightseq/lightseq_async_attn.py", line 740, in <module>
[rank7]:     test_op(1, 16, N_CTX, 128, True)
[rank7]:   File "/export/home/research/mm/LightSeq/lightseq/lightseq_async_attn.py", line 501, in test_op
[rank7]:     tri_out = attention(real_q, real_k, real_v, causal, sm_scale).half()
[rank7]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/export/home/miniconda3/lib/python3.11/site-packages/torch/autograd/function.py", line 606, in apply
[rank7]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/export/home/research/mm/LightSeq/lightseq/lightseq_async_attn.py", line 447, in forward
[rank7]:     q, k, v, o, L = _lightseq_forward(q, k, v, causal, sm_scale, comm_mode)
[rank7]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/export/home/research/mm/LightSeq/lightseq/lightseq_async_attn.py", line 304, in _lightseq_forward
[rank7]:     fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), m, l, o, L, True, is_last_time(time_step))
[rank7]:   File "/export/home/research/mm/LightSeq/lightseq/lightseq_async_attn.py", line 271, in <lambda>
[rank7]:     fwd_launch_helper = lambda q, k, v, m, l, o, L, IS_CAUSAL, LAST_STEP: _fwd_kernel[grid](
[rank7]:                                                                           ^^^^^^^^^^^^^^^^^^
[rank7]:   File "/export/home/miniconda3/lib/python3.11/site-packages/triton/runtime/jit.py", line 180, in <lambda>
[rank7]:     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
[rank7]:                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/export/home/miniconda3/lib/python3.11/site-packages/triton/runtime/jit.py", line 412, in run
[rank7]:     kernel.run(grid_0, grid_1, grid_2, metadata.num_warps,
[rank7]:     ^^^^^^^^^^
[rank7]:   File "/export/home/miniconda3/lib/python3.11/site-packages/triton/compiler/compiler.py", line 339, in __getattribute__
[rank7]:     self._init_handles()
[rank7]:   File "/export/home/miniconda3/lib/python3.11/site-packages/triton/compiler/compiler.py", line 332, in _init_handles
[rank7]:     raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
[rank7]: triton.runtime.autotuner.OutOfResources: out of resource: shared memory, Required: 233472, Hardware limit: 232448. Reducing block sizes or `num_stages` may help.
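
For reference, the error says the kernel needs 233472 bytes of shared memory while the GPU exposes only 232448 bytes per block. A minimal sketch for querying that per-block limit on the current device, using Triton's internal driver utilities (the attribute path is an assumption based on Triton 3.0 internals and is not a stable API):

```python
# Hypothetical check of the GPU's per-block shared-memory limit (the
# "Hardware limit" value in the OutOfResources error). Relies on Triton 3.0
# internals, which may move between releases.
import torch
import triton

device = torch.cuda.current_device()
props = triton.runtime.driver.active.utils.get_device_properties(device)
print("max shared memory per block:", props["max_shared_mem"], "bytes")
```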

The Triton and PyTorch packages I am using are the latest nightly versions installed from the PyTorch website, and flash-attn is installed from source:

triton.__version__ == 3.0.0+989adb9a29
torch.__version__==2.4.0.dev20240325+cu121
flash_attn.__version__ == 2.5.6

Do you have any idea what the problem could be?

Thanks!

linyubupa commented 4 months ago

Same problem on 8×A100 with flash-attn 2.5.6, triton 2.2.0, torch 2.2.2.

DachengLi1 commented 4 months ago

Interesting, we have tested this on A100.

But this is because the block size is too large, so the kernel does not fit in SRAM (shared memory). You can reduce the forward block size to make it fit: https://github.com/RulinShao/LightSeq/blob/main/lightseq/lightseq_async_attn.py#L252, e.g. reduce it to 64, 64. If you encounter the same problem in the backward pass, you can reduce the backward block size as well: https://github.com/RulinShao/LightSeq/blob/main/lightseq/lightseq_async_attn.py#L342
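
A minimal sketch of that change, assuming the forward launch sets Triton block-size constants like the ones below (the names BLOCK_M/BLOCK_N and the original values are assumptions; check the line linked above):

```python
# Hypothetical excerpt around lightseq_async_attn.py#L252: shrink the Triton
# tile sizes so the kernel's shared-memory footprint fits within the GPU limit.
# BLOCK_M, BLOCK_N = 128, 128   # assumed original values -> OutOfResources
BLOCK_M, BLOCK_N = 64, 64       # smaller tiles need less shared memory

# If the backward pass raises the same OutOfResources error, apply the same
# reduction to the backward block sizes (around #L342). Lowering num_stages
# in the kernel launch, as the error message suggests, also reduces the
# shared-memory requirement.
```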

yihaocs commented 4 months ago

Thanks! But it seems weird that it works on A100 but not on H100.