sustcsonglin / flash-linear-attention

Efficient implementations of state-of-the-art linear attention models in PyTorch and Triton
MIT License

Add implementations of Mamba 2 into FLA #34

Closed: DanFosing closed this issue 3 months ago

DanFosing commented 3 months ago

Hello! Do you plan to add Mamba 2 to your repo? If so, any estimate on when we can expect it?

yzhangcs commented 3 months ago

@DanFosing Thanks for your attention. Yes, we indeed plan to add the kernels & models into fla. @sustcsonglin is playing with it, please stay tuned.

DanFosing commented 3 months ago

@yzhangcs I made a pull request (#39) implementing Mamba 2 (the modeling.py file made for Mamba Codestral).

yzhangcs commented 2 months ago

@DanFosing We finally have the simple GLA / Gated RetNet kernel, which is compatible with and significantly faster than Mamba 2, thanks to the great work by you and @learning-chip (#39, #49 and #50).
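For anyone landing here, a minimal usage sketch of the kernel (the import path, tensor layout, and return convention are assumptions on my part, not a spec; check the fla source for the current interface):

```python
import torch
from fla.ops.simple_gla import chunk_simple_gla  # import path assumed from the repo layout

# Sketch only: shapes, dtypes and the g layout are assumptions.
B, H, T, D = 2, 4, 512, 64
q = torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16)
# g: per-head, per-step log decay (a scalar per step, like Mamba 2's gate);
# logsigmoid keeps it in (-inf, 0).
g = torch.nn.functional.logsigmoid(torch.randn(B, H, T, device='cuda', dtype=torch.float32))

# Return convention (output only vs. output plus final state) may differ by version.
out = chunk_simple_gla(q, k, v, g)
```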

SmerkyG commented 2 months ago

@yzhangcs Did the recent changes to simple GLA maybe break the backward()? The benchmark runs fine for me but when training I get `RuntimeError: function SimpleGLAFunctionBackward returned an incorrect number of gradients (expected 7, got 6)`

sustcsonglin commented 2 months ago

> @yzhangcs Did the recent changes to simple GLA maybe break the backward()? The benchmark runs fine for me but when training I get `RuntimeError: function SimpleGLAFunctionBackward returned an incorrect number of gradients (expected 7, got 6)`

Thanks for catching the bug! Just fixed.
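For context, this error comes from the `torch.autograd.Function` contract: `backward()` must return exactly one gradient per argument of `forward()` (using `None` for non-differentiable inputs), so adding a new forward argument without a matching return slot produces exactly this message. A minimal illustration, not the actual fla kernel code:

```python
import torch

class Scale(torch.autograd.Function):
    """Toy example: forward takes three inputs, so backward returns three gradients."""

    @staticmethod
    def forward(ctx, x, scale, verbose):
        ctx.scale = scale
        if verbose:
            print("scaling by", scale)
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # One gradient per forward input: d(out)/dx, and None for the
        # non-tensor inputs `scale` and `verbose`. Returning only two values
        # here would raise "expected 3, got 2".
        return grad_out * ctx.scale, None, None
```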

SmerkyG commented 2 months ago

Works great, thanks!!!

One more quick note: when using `torch_simple_gla` I get NaN after a few iterations... With `chunk_simple_gla` I also get NaN fairly quickly, though it takes a while longer, around 325 iterations. Training with my simple implementation, this doesn't happen.

simple implementation pseudocode:

```python
import torch

def segsum(w_log):  # w_log: (B, H, L, 1) per-step log decay
    w_log_cumsum = torch.cumsum(w_log, dim=-2)  # (B, H, L, 1)
    # The inner tril() zeroes the upper triangle before exp() to avoid overflow;
    # the outer tril() zeroes the resulting ones afterwards.
    return torch.exp((w_log_cumsum - w_log_cumsum.mT).tril()).tril()  # (B, H, L, L)

att = (q * q.size(-1) ** -0.5) @ k.mT
att = att * segsum(w_log)  # segsum handles zeroing the upper-right triangle
out = att @ v
```

Update: it works fine as long as I clamp the g (w_log) values to around -5... I guess you must be using the original GLA method of computing this via relative changes to q and k, so there's a precision limit.
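A minimal sketch of that workaround, assuming w_log is the per-step log decay fed into segsum() above (the -5 bound is just the empirical value from this thread, not an fla default):

```python
import torch

# Bound the per-step log decay before it enters segsum(), so the exponentiated
# cumulative-sum differences stay within floating-point range.
w_log = torch.nn.functional.logsigmoid(torch.randn(2, 4, 512, 1))  # (B, H, L, 1), hypothetical values
w_log = w_log.clamp(min=-5.0)  # -5 is the empirical bound from this thread
```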

SmerkyG commented 2 months ago

I also get an error about `num_warps` being an unrecognized argument when using `torch.compile` on `chunk_simple_gla`:


```
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/codecache.py", line 934, in _run_from_cache
[rank0]:     return compiled_graph.compiled_artifact(inputs)
[rank0]:   File "/tmp/torchinductor_recursal/uz/cuz4byetjs4tgjziqso7bh6qlsoayncgrauaw7aei7y56e5t2re7.py", line 334, in call
[rank0]:     chunk_simple_gla_fwd_kernel_h_0.run(k=buf2, v=buf3, h=buf0, g=reinterpret_tensor(buf1, (32, 12, 512), (6144, 512, 1), 0), h0=arg3_1, ht=None, s_qk_h=32768, s_qk_t=64, s_qk_d=1, s_vo_h=32768, s_vo_t=64, s_vo_d=1, s_h_h=32768, s_h_t=64, T=512, K=64, V=64, BT=64, BK=64, BV=64, NT=8, USE_INITIAL_STATE=True, STORE_FINAL_STATE=False, num_warps=4, num_stages=1, grid=grid_wrapper_for_chunk_simple_gla_fwd_kernel_h_0, stream=stream0)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/triton_heuristics.py", line 670, in run
[rank0]:     return launcher(
[rank0]: TypeError: launcher() got an unexpected keyword argument 'num_warps'
```
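One possible way to sidestep this while it's unresolved (an assumption on my part, not a verified fix) is to exclude the Triton-backed op from dynamo/inductor tracing, so the hand-written kernel keeps launching itself with its own `num_warps`/`num_stages` instead of going through inductor's code cache:

```python
import torch
from fla.ops.simple_gla import chunk_simple_gla  # import path assumed

# Hypothetical wrapper: torch._dynamo.disable makes torch.compile fall back to
# eager for this call only, while the rest of the model is still compiled.
@torch._dynamo.disable
def simple_gla_eager(q, k, v, g):
    return chunk_simple_gla(q, k, v, g)
```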