sustcsonglin / flash-linear-attention

Efficient implementations of state-of-the-art linear attention models in PyTorch and Triton
MIT License

RuntimeError: Triton Error [CUDA]: invalid argument #64

Closed. TiminHu closed this issue 2 months ago.

TiminHu commented 2 months ago

Describe the bug

Hi, I encountered a problem while using the fused_chunk_gla function. My input has a relatively short sequence length of 20, but the batch size can go up to 40960, with 2 attention heads and a head dimension of 64. With such input the function throws the RuntimeError shown in the title. However, when I reduce the batch size to 1024 or 256, it runs fine. It seems like a memory issue, but I haven't observed any memory overuse on either the CPU or GPU. Do you have any idea why this might be happening?

Steps to reproduce the bug

```python
import torch
from fla.ops.gla import fused_chunk_gla

q = torch.randn([40960, 2, 20, 64]).cuda()
k = torch.randn([40960, 2, 20, 64]).cuda()
v = torch.randn([40960, 2, 20, 64]).cuda()
gk = torch.randn([40960, 2, 20, 64]).cuda()
o, new_hidden_states = fused_chunk_gla(q, k, v, gk, output_final_state=True)
```

Expected behavior

I would like the function to run successfully while keeping the batch size at 40960.

Environment info

  1. torch: 2.4.1
  2. triton: 3.0.0
yzhangcs commented 2 months ago

@TiminHu Hi, thank you for reporting this bug. However, we are not going to maintain this kernel in the future, and `fused_chunk` will be marked as deprecated in the next release. We recommend using `chunk`, which is faster than `fused_chunk` in most scenarios.
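
A minimal sketch of that swap, assuming `chunk_gla` is imported from `fla.ops.gla` and accepts the same positional arguments as `fused_chunk_gla` (an illustration, not an excerpt from the library docs):

```python
import torch
from fla.ops.gla import chunk_gla  # assumed import path, alongside fused_chunk_gla

q = torch.randn([1024, 2, 20, 64]).cuda()
k = torch.randn([1024, 2, 20, 64]).cuda()
v = torch.randn([1024, 2, 20, 64]).cuda()
gk = torch.randn([1024, 2, 20, 64]).cuda()

# assumed to mirror fused_chunk_gla's call signature
o, final_state = chunk_gla(q, k, v, gk, output_final_state=True)
```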

sustcsonglin commented 2 months ago

In your case I suggest using `fused_recurrent`.
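
A minimal sketch of that suggestion, assuming `fused_recurrent_gla` lives in `fla.ops.gla` and shares the `fused_chunk_gla` call signature:

```python
from fla.ops.gla import fused_recurrent_gla  # assumed import path

# q, k, v, gk as in the reproduction above
o, final_state = fused_recurrent_gla(q, k, v, gk, output_final_state=True)
```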

TiminHu commented 2 months ago

> In your case I suggest using fused_recurrent

But when I use `fused_recurrent`, the error still occurs.

yzhangcs commented 2 months ago

@TiminHu Hi, refer to https://github.com/triton-lang/triton/issues/580

A batch size of 40960 is too large to launch: with 2 heads, the batch * head dimension of the Triton launch grid is 40960 * 2 = 81920, which exceeds CUDA's 65535 limit on the non-leading grid axes, hence the "invalid argument" error.
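
If the full batch of 40960 is required, one workaround consistent with that diagnosis is to slice the batch dimension into pieces small enough to launch and concatenate the results. A sketch, assuming `fused_recurrent_gla` from `fla.ops.gla` (the chunk size of 16384 is an arbitrary value chosen to keep batch * heads below 65535, not a library recommendation):

```python
import torch
from fla.ops.gla import fused_recurrent_gla  # assumed import path

def gla_in_slices(q, k, v, gk, max_batch=16384):
    """Run the kernel on slices of the batch dimension so that
    batch * heads stays below the CUDA grid-size limit."""
    outs, states = [], []
    for start in range(0, q.shape[0], max_batch):
        sl = slice(start, start + max_batch)
        o, state = fused_recurrent_gla(q[sl], k[sl], v[sl], gk[sl], output_final_state=True)
        outs.append(o)
        states.append(state)
    return torch.cat(outs, dim=0), torch.cat(states, dim=0)

q, k, v, gk = (torch.randn([40960, 2, 20, 64]).cuda() for _ in range(4))
o, final_state = gla_in_slices(q, k, v, gk)
```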