@TiminHu Hi, thank you for reporting this bug. However, we're not going to maintain this kernel in the future, and fused_chunk will be marked as deprecated in the next release. We recommend using chunk instead, which is faster than fused_chunk in most scenarios. In your case, with such a short sequence length, I suggest using fused_recurrent.
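For reference, a minimal sketch of the suggested swap, assuming the standard fla.ops.gla entry points (chunk_gla, fused_recurrent_gla) and the (batch, heads, seq_len, head_dim) layout used in this issue; the exact import path, layout, and return values may differ across fla versions:

```python
import torch
from fla.ops.gla import chunk_gla, fused_recurrent_gla  # assumed import path

# Small toy shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(8, 2, 20, 64, device='cuda')
k = torch.randn(8, 2, 20, 64, device='cuda')
v = torch.randn(8, 2, 20, 64, device='cuda')
gk = torch.randn(8, 2, 20, 64, device='cuda')

# chunk kernel: generally the faster choice in most scenarios
o, final_state = chunk_gla(q, k, v, gk, output_final_state=True)

# fused_recurrent kernel: often preferable for very short sequences (e.g. seq_len = 20)
o, final_state = fused_recurrent_gla(q, k, v, gk, output_final_state=True)
```

If the signatures match your installed version, either call is a drop-in replacement for fused_chunk_gla.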
But when I use fused_recurrent, the error still exists.
@TiminHu Hi, refer to https://github.com/triton-lang/triton/issues/580: a batch size of 40960 results in a launch grid that is too large to launch.
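If the launch grid really is the limit, one possible workaround is to split the oversized batch into slices that are small enough to launch (the report below says 1024 works) and concatenate the results, since batch items are independent. This is only a sketch; the import path and the helper name fused_chunk_gla_sliced are assumptions, not part of fla:

```python
import torch
from fla.ops.gla import fused_chunk_gla  # assumed import path

def fused_chunk_gla_sliced(q, k, v, gk, max_batch=1024):
    """Run fused_chunk_gla on batch slices small enough to launch, then concatenate."""
    outs, states = [], []
    for start in range(0, q.shape[0], max_batch):
        sl = slice(start, start + max_batch)
        o, state = fused_chunk_gla(q[sl], k[sl], v[sl], gk[sl], output_final_state=True)
        outs.append(o)
        states.append(state)
    # Outputs and final states are both batch-major, so concatenate along dim 0.
    return torch.cat(outs, dim=0), torch.cat(states, dim=0)
```

Calling `o, new_hidden_states = fused_chunk_gla_sliced(q, k, v, gk)` then reproduces the original call while keeping each individual launch within the batch size that is reported to work.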
Describe the bug
Hi, I encountered a problem while using the fused_chunk_gla function. My input has a relatively short sequence length of 20, but the batch size can go up to 40960, with 2 attention heads and a head dimension of 64. With such input, the function throws the error shown above. However, when I reduce the batch size to 1024 or 256, it runs fine. It seems to be a memory issue, but I haven't observed any memory overuse on either the CPU or GPU. Do you have any idea why this might be happening?
Steps to reproduce the bug
```python
import torch
from fla.ops.gla import fused_chunk_gla  # import path may differ across fla versions

q = torch.randn([40960, 2, 20, 64]).cuda()
k = torch.randn([40960, 2, 20, 64]).cuda()
v = torch.randn([40960, 2, 20, 64]).cuda()
gk = torch.randn([40960, 2, 20, 64]).cuda()
o, new_hidden_states = fused_chunk_gla(q, k, v, gk, output_final_state=True)
```
Expected behavior
I would still like to keep the batch size at 40960 and have the function run successfully.
Environment info