sustcsonglin / flash-linear-attention

Efficient implementations of state-of-the-art linear attention models in PyTorch and Triton
MIT License

[Bug]: new autotune error running simple gla chunked #67

Closed: SmerkyG closed this issue 1 week ago

SmerkyG commented 1 week ago

Describe the bug

Hi,

Sorry to bother you. I recently updated FLA on an 8xH100 machine, and fla.ops.simple_gla.chunk now raises new errors during autotuning that were not present before. It must be the result of a fairly recent change, though unfortunately I don't know exactly which commit last worked. fla.ops.gla.fused_chunk continues to work fine for me, but fla.ops.gla.chunk gives me the same kind of autotune index error on a different machine (8x4090).

It appears to be an autotune configuration issue, possibly caused by autotune keys that have no matching arguments in the kernel call.

I also suspect that, once this is fixed, the earlier bug https://github.com/sustcsonglin/flash-linear-attention/issues/58 will have regressed, since I now see num_warps=8 in the new code.
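If the num_warps=8 failures from #58 do come back on H100, one hypothetical workaround is to drop 8-warp configs before autotuning on Hopper GPUs (compute capability 9.x). The sketch below is not FLA's code: plain dicts stand in for triton.Config objects so it runs without a GPU, and the cap of 4 warps is an assumption. In a real kernel, the capability could come from torch.cuda.get_device_capability(), and the filter could be hooked into the early_config_prune entry of triton.autotune's prune_configs_by argument.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for triton.Config: a plain dict of meta-parameters.
Config = Dict[str, int]


def prune_warps_for_device(
    configs: List[Config],
    capability: Tuple[int, int],
    max_warps_on_hopper: int = 4,  # assumed cap, based on the num_warps=8 failures in #58
) -> List[Config]:
    """Drop configs whose num_warps exceeds a cap on compute capability 9.x (e.g. H100)."""
    if capability[0] != 9:
        # Non-Hopper GPUs (e.g. RTX 4090 is 8.9) keep every candidate.
        return configs
    kept = [c for c in configs if c["num_warps"] <= max_warps_on_hopper]
    # Autotuning needs at least one config, so fall back to the smallest-warp one.
    return kept or [min(configs, key=lambda c: c["num_warps"])]


candidates = [{"BK": 64, "BV": 64, "num_warps": w} for w in (1, 2, 4, 8)]
print(prune_warps_for_device(candidates, (9, 0)))  # H100: the 8-warp config is dropped
print(prune_warps_for_device(candidates, (8, 9)))  # RTX 4090: unchanged
```

Filtering before autotuning, rather than catching the crash afterwards, means the problematic config is never compiled or launched on the affected hardware.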

This is the error I'm seeing:

    attn_output = fla_chunk_simple_gla(query_states, key_states, value_states, decay_states_log.view(bsz, self.num_heads, q_len))[0]
  File "/workspace/smerky/GoldFinch-paper/models/qwen2.py", line 39, in fla_chunk_simple_gla
    o, final_state = SimpleGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state)
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/fla/utils.py", line 12, in wrapper
    return fn(ctx,
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 455, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/fla/ops/simple_gla/chunk.py", line 334, in forward
    h, final_state = chunk_fwd_h_fn(k=k,v=v,g=g,gk=None,gv=None,BT=BT,h0=initial_state,output_final_state=output_final_state,states_in_fp32=False)
  File "/usr/local/lib/python3.10/dist-packages/fla/ops/common/chunk_h.py", line 214, in chunk_fwd_h_fn
    chunk_fwd_kernel_h[grid](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 146, in run
    key = [_args[i] for i in self.key_idx]
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 146, in <listcomp>
    key = [_args[i] for i in self.key_idx]
IndexError: list index out of range
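The failing line is key = [_args[i] for i in self.key_idx]. The pure-Python model below is a simplified sketch, not Triton's actual code. It assumes, as that list comprehension suggests, that key_idx holds the positions of the key names in the kernel's full signature, while _args is rebuilt on each call from only the arguments actually supplied. Under that assumption, an argument filled in by an autotune config instead of being passed explicitly shifts later positions, and a key index can run past the end of the list. The kernel signature used here is hypothetical, not the real chunk_fwd_kernel_h signature.

```python
from typing import Any, Dict, List, Sequence


def build_autotune_key(
    arg_names: Sequence[str],
    key_names: Sequence[str],
    call_args: Dict[str, Any],
) -> List[Any]:
    """Simplified model of the autotuner key lookup shown in the traceback."""
    # Positions of the key names in the full kernel signature.
    key_idx = [arg_names.index(k) for k in key_names]
    # Only supplied arguments survive, so _args can be shorter than arg_names.
    _args = [call_args[name] for name in arg_names if name in call_args]
    # Mirrors `key = [_args[i] for i in self.key_idx]`: raises IndexError when an
    # index points past the end of the shortened list.
    return [_args[i] for i in key_idx]


# Hypothetical signature: BK and BV come from autotune configs, so they are not passed.
arg_names = ["k", "v", "h", "BK", "BV", "T", "K"]
supplied = {"k": 0, "v": 1, "h": 2, "T": 128, "K": 64}
try:
    build_autotune_key(arg_names, ["K"], supplied)
except IndexError as err:
    print("IndexError:", err)
```

In this model, a key that sits before every omitted argument still resolves correctly, which would explain why only some kernels are affected. When a shifted index still falls inside the list, the model silently picks the wrong argument instead of raising.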

Steps to reproduce the bug

I am using the following calling code:

from typing import Tuple

import torch
from fla.ops.simple_gla.chunk import chunk_simple_gla, SimpleGLAFunction

def fla_chunk_simple_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,  # log decay, shape (b, h, l)
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert q.dim() == k.dim() == v.dim() == 4, "q, k, v must have 4 dimensions (b, h, l, d)"
    assert q.dtype == k.dtype == v.dtype, "q, k, v must have the same dtype"
    scale = k.shape[-1] ** -0.5  # standard 1/sqrt(d_k) scaling
    g = g.float()  # pass the log decay in fp32
    initial_state = None
    output_final_state = False
    # Call the autograd Function directly, with no initial state
    o, final_state = SimpleGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state)
    return o, final_state
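To check that an updated FLA returns correct values, and not only that it no longer crashes, a slow reference can help. The sketch below is not FLA's implementation. It assumes simple GLA follows the usual recurrence with one scalar log decay per head and time step, consistent with the (b, h, l) shape of g above: S_t = exp(g_t) * S_{t-1} + k_t^T v_t and o_t = scale * q_t S_t. It needs only PyTorch and runs on CPU.

```python
from typing import Optional, Tuple

import torch


def naive_recurrent_simple_gla(
    q: torch.Tensor,  # (b, h, l, d_k)
    k: torch.Tensor,  # (b, h, l, d_k)
    v: torch.Tensor,  # (b, h, l, d_v)
    g: torch.Tensor,  # (b, h, l) log decay, one scalar per head and step
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,  # (b, h, d_k, d_v)
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Assumed recurrence: S_t = exp(g_t) * S_{t-1} + k_t^T v_t, o_t = scale * q_t S_t."""
    b, h, l, d_k = q.shape
    d_v = v.shape[-1]
    if scale is None:
        scale = d_k ** -0.5
    # Work in float32 so the reference is not limited by half precision.
    q, k, v, g = q.float() * scale, k.float(), v.float(), g.float()
    if initial_state is None:
        S = q.new_zeros(b, h, d_k, d_v)
    else:
        S = initial_state.float().clone()
    o = q.new_empty(b, h, l, d_v)
    for t in range(l):
        # Decay the state by the scalar gate, then add the rank-1 update k_t^T v_t.
        S = S * g[:, :, t].exp()[..., None, None] + k[:, :, t, :, None] * v[:, :, t, None, :]
        # Read out o_t = q_t S_t.
        o[:, :, t] = torch.einsum("bhk,bhkv->bhv", q[:, :, t], S)
    return o, S
```

On small random CUDA inputs, the maximum absolute difference between fla_chunk_simple_gla(q, k, v, g)[0] and this reference gives a quick sanity check after updating. The loop is O(l) sequential steps, so it is only practical for short sequences.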

Expected behavior

no error :)

Environment info

  1. torch: 2.4.1
  2. triton: 3.0.0
  3. GPU: 8xH100

sustcsonglin commented 1 week ago

Hi, which commit are you using?
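One way to answer "which commit are you using?" is to read the metadata pip records at install time. The stdlib-only sketch below assumes FLA was installed with pip from the GitHub URL, in which case pip writes a PEP 610 direct_url.json file containing the VCS commit. Editable or local-directory installs typically record no commit, and the distribution name "fla" is an assumption.

```python
import json
from importlib import metadata
from typing import Optional, Tuple


def commit_from_direct_url(text: Optional[str]) -> Optional[str]:
    """Extract the VCS commit hash from a PEP 610 direct_url.json payload, if any."""
    if not text:
        return None
    info = json.loads(text)
    return info.get("vcs_info", {}).get("commit_id")


def installed_version_and_commit(dist_name: str) -> Tuple[Optional[str], Optional[str]]:
    """Return (version, commit) for an installed distribution, or (None, None) if absent."""
    try:
        dist = metadata.distribution(dist_name)
    except metadata.PackageNotFoundError:
        return None, None
    # read_text returns None when pip did not record direct_url.json.
    return dist.version, commit_from_direct_url(dist.read_text("direct_url.json"))


for name in ("torch", "triton", "fla"):  # "fla" as the distribution name is an assumption
    print(name, installed_version_and_commit(name))
```

Pasting this output into the "Environment info" section pins down the exact FLA revision alongside the torch and Triton versions.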

sustcsonglin commented 1 week ago

We've recently fixed a bug that raised this error. The latest commit should work.

SmerkyG commented 1 week ago

Thanks, I'll update and see! There is probably still a regression of the earlier bug https://github.com/sustcsonglin/flash-linear-attention/issues/58, where H100 dies on num_warps=8, but I will let you know once I try it!

SmerkyG commented 1 week ago

Works great! Sorry for the false alarm. So far there are no problems on H100; I'll reopen an issue if any of the kernels fail on it. Thanks!!!