DanFosing closed this issue 3 months ago.
@DanFosing Thanks for your attention. Yes, we do plan to add the kernels and models to fla. @sustcsonglin is working on it; please stay tuned.
@yzhangcs I opened pull request #39 implementing Mamba 2 (a modeling.py file written for Mamba Codestral).
@DanFosing We finally have the simple GLA / Gated RetNet kernel, which is compatible with and significantly faster than Mamba2, thanks to the great job by you and @learning-chip (#39, #49 and #50).
@yzhangcs Did the recent changes to simple GLA break backward()? The benchmark runs fine for me, but when training I get:
```
RuntimeError: function SimpleGLAFunctionBackward returned an incorrect number of gradients (expected 7, got 6)
```
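For context, this error comes from the contract of PyTorch's custom `torch.autograd.Function`: `backward` must return exactly one value per argument passed to `forward`, using `None` for non-differentiable arguments such as scalars or flags. A minimal sketch with a hypothetical `ScaledMatmul` function (not fla's actual `SimpleGLAFunction`):

```python
import torch

class ScaledMatmul(torch.autograd.Function):
    """Hypothetical example: forward takes 3 inputs (two tensors and a float)."""

    @staticmethod
    def forward(ctx, a, b, scale):
        ctx.save_for_backward(a, b)
        ctx.scale = scale  # plain Python float, not differentiable
        return (a @ b) * scale

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        grad_a = (grad_out @ b.mT) * ctx.scale
        grad_b = (a.mT @ grad_out) * ctx.scale
        # One entry per forward input. Returning only (grad_a, grad_b) would raise
        # "returned an incorrect number of gradients (expected 3, got 2)".
        return grad_a, grad_b, None  # None for the non-tensor `scale`
```

Adding an argument to `forward` without adding a matching entry (often `None`) to the tuple returned by `backward` produces exactly this kind of mismatch.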
Thanks for catching the bug! Just fixed.
Works great, thanks!!!
One more quick note: when using torch_simple_gla I get NaNs after a few iterations. With chunk_simple_gla I also get NaNs fairly quickly, though it takes a while longer (around 325 iterations). Training with my simple implementation, this doesn't happen.
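Simple implementation (PyTorch):
```python
import torch

def segsum(w_log):  # w_log: (B, H, L, 1), per-step log decay
    w_log_cumsum = torch.cumsum(w_log, dim=-2)  # (B, H, L, 1)
    # [i, j] = exp(sum of w_log over steps j+1..i). The inner tril() zeroes the upper
    # triangle before exp so it cannot overflow; the outer tril() then masks it to 0.
    return torch.exp((w_log_cumsum - w_log_cumsum.mT).tril()).tril()  # (B, H, L, L)

def simple_gla(q, k, v, w_log):  # q, k: (B, H, L, K); v: (B, H, L, V)
    att = (q * q.size(-1) ** -0.5) @ k.mT
    att = att * segsum(w_log)  # segsum handles zeroing the upper-right triangle
    return att @ v  # (B, H, L, V)
```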
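As a sanity check on the segsum mask, the parallel form above should match a step-by-step recurrence S_t = exp(g_t) · S_{t-1} + k_tᵀ v_t with output o_t = (q_t · scale) S_t. A minimal sketch, assuming a scalar log decay per head and time step (shape (B, H, L, 1)) and a zero initial state; the function names are hypothetical:

```python
import torch

def segsum_decay(w_log):  # (B, H, L, 1) -> (B, H, L, L)
    c = torch.cumsum(w_log, dim=-2)
    return torch.exp((c - c.mT).tril()).tril()  # [i, j] = exp(c_i - c_j) for j <= i

def parallel_form(q, k, v, w_log):
    att = (q * q.size(-1) ** -0.5) @ k.mT
    return (att * segsum_decay(w_log)) @ v

def recurrent_form(q, k, v, w_log):
    B, H, L, K = q.shape
    scale = K ** -0.5
    S = q.new_zeros(B, H, K, v.size(-1))  # running state, zero initial state
    outs = []
    for t in range(L):
        decay = torch.exp(w_log[:, :, t]).unsqueeze(-1)            # (B, H, 1, 1)
        kv = k[:, :, t].unsqueeze(-1) @ v[:, :, t].unsqueeze(-2)   # (B, H, K, V)
        S = decay * S + kv
        outs.append((q[:, :, t] * scale).unsqueeze(-2) @ S)        # (B, H, 1, V)
    return torch.cat(outs, dim=-2)                                 # (B, H, L, V)

q, k, v = (torch.randn(2, 3, 16, 8, dtype=torch.float64) for _ in range(3))
w_log = -torch.rand(2, 3, 16, 1, dtype=torch.float64)  # log decay <= 0
print(torch.allclose(parallel_form(q, k, v, w_log), recurrent_form(q, k, v, w_log)))  # True
```

The two agree because the coefficient of k_j v_jᵀ in S_i is exp(g_{j+1} + … + g_i) = exp(c_i − c_j), which is exactly the lower-triangular entry produced by the segsum.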
Update: it works fine as long as I clamp the g (w_log) values to around -5. I guess you're using the original GLA method, which computes this via relative rescalings of q and k, so there's a precision limit.
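To illustrate the kind of precision limit guessed at here, a hypothetical sketch (not fla's kernel code): folding the cumulative decay into q and k as q_i · exp(c_i) and k_j · exp(−c_j) is mathematically equal to weighting q_i · k_j by exp(c_i − c_j), but in float32 exp overflows above roughly 88.7, so once the cumulative log decay is large the factors become 0 and inf and their product is NaN. Computing the difference c_i − c_j first, as the segsum version does, stays finite. The function names and the decay values below are illustrative assumptions:

```python
import torch

def decay_stable(q, k, w_log):
    # exp(c_i - c_j) from the difference: entries stay in [0, 1] when w_log <= 0
    c = torch.cumsum(w_log, dim=-2)
    return (q @ k.mT) * torch.exp((c - c.mT).tril()).tril()

def decay_factored(q, k, w_log):
    # Hypothetical "relative rescaling" form: exp(c_i) folded into q, exp(-c_j) into k
    c = torch.cumsum(w_log, dim=-2)
    L = q.size(-2)
    mask = torch.ones(L, L, dtype=torch.bool, device=q.device).tril()
    att = (q * torch.exp(c)) @ (k * torch.exp(-c)).mT
    return att * mask  # the mask cannot remove NaNs: nan * 0 is still nan

torch.manual_seed(0)
q, k = torch.randn(1, 1, 32, 16), torch.randn(1, 1, 32, 16)
strong = torch.full((1, 1, 32, 1), -8.0)  # cumulative log decay reaches -256
print(torch.isnan(decay_factored(q, k, strong)).any().item())  # True
print(torch.isnan(decay_stable(q, k, strong)).any().item())    # False
```

With mild decay (e.g. w_log = -0.01) the two functions agree to float32 rounding; they only diverge once exp(−c_j) overflows.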
I also get an error about num_warps being an unexpected keyword argument when using torch.compile on chunk_simple_gla:
```
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/codecache.py", line 934, in _run_from_cache
[rank0]: return compiled_graph.compiled_artifact(inputs)
[rank0]: File "/tmp/torchinductor_recursal/uz/cuz4byetjs4tgjziqso7bh6qlsoayncgrauaw7aei7y56e5t2re7.py", line 334, in call
[rank0]: chunk_simple_gla_fwd_kernel_h_0.run(k=buf2, v=buf3, h=buf0, g=reinterpret_tensor(buf1, (32, 12, 512), (6144, 512, 1), 0), h0=arg3_1, ht=None, s_qk_h=32768, s_qk_t=64, s_qk_d=1, s_vo_h=32768, s_vo_t=64, s_vo_d=1, s_h_h=32768, s_h_t=64, T=512, K=64, V=64, BT=64, BK=64, BV=64, NT=8, USE_INITIAL_STATE=True, STORE_FINAL_STATE=False, num_warps=4, num_stages=1, grid=grid_wrapper_for_chunk_simple_gla_fwd_kernel_h_0, stream=stream0)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/triton_heuristics.py", line 670, in run
[rank0]: return launcher(
[rank0]: TypeError: launcher() got an unexpected keyword argument 'num_warps'
```
Hello! Do you plan to add Mamba 2 to your repo? If so, any estimate on when we can expect it?