state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

mamba-ssm-2.2.0 has an error #431

Status: Open. LoserCheems opened this issue 2 months ago.

LoserCheems commented 2 months ago

The following error occurs when using mamba_chunk_scan_combined in mamba-ssm-2.2.0:

    File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 522, in backward
      torch.autograd.backward(
    File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 266, in backward
      Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    RuntimeError: function MambaChunkScanCombinedFnBackward returned an incorrect number of gradients (expected 16, got 14)

However, mamba-ssm-2.1.0 does not have this problem.

tridao commented 2 months ago

Thanks for the bug report. I've just pushed a fix; it will take 1-2 hours for the wheels to compile.
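For context, the RuntimeError in the report comes from PyTorch's torch.autograd.Function machinery: a custom Function's backward must return one entry per input passed to forward (excluding ctx), with None for inputs that need no gradient, non-tensor arguments included. "expected 16, got 14" therefore most likely means forward received 16 inputs while backward returned only 14 values. The sketch below is a hypothetical minimal reproduction, not the mamba-ssm code: ScaleFnBuggy, ScaleFnFixed, grad_of_x and the add_bias argument are invented for illustration, assuming a recent PyTorch. It shows how adding an argument to forward without adding a matching None to backward produces the same class of error.

```python
import torch


class ScaleFnBuggy(torch.autograd.Function):
    """Hypothetical op y = x * scale (+ 1 if add_bias); NOT mamba-ssm code."""

    @staticmethod
    def forward(ctx, x, scale, add_bias):
        # Three inputs after ctx: x (tensor), scale (float), add_bias (bool).
        ctx.scale = scale
        y = x * scale
        return y + 1.0 if add_bias else y

    @staticmethod
    def backward(ctx, grad_out):
        # Bug: add_bias was added to forward, but no None slot was added here,
        # so autograd receives 2 gradients where it expects 3.
        return grad_out * ctx.scale, None


class ScaleFnFixed(torch.autograd.Function):
    """Same hypothetical op, with one gradient slot per forward input."""

    @staticmethod
    def forward(ctx, x, scale, add_bias):
        ctx.scale = scale
        y = x * scale
        return y + 1.0 if add_bias else y

    @staticmethod
    def backward(ctx, grad_out):
        # dy/dx = scale; the non-tensor inputs (scale, add_bias) get None.
        return grad_out * ctx.scale, None, None


def grad_of_x(fn):
    """Run forward + backward through `fn` and return the gradient w.r.t. x."""
    x = torch.randn(4, requires_grad=True)
    fn.apply(x, 2.0, False).sum().backward()
    return x.grad


if __name__ == "__main__":
    try:
        grad_of_x(ScaleFnBuggy)
    except RuntimeError as err:
        # Same class of error as in the report, here with (expected 3, got 2).
        print(err)
    print(grad_of_x(ScaleFnFixed))  # every entry is 2.0, i.e. d(sum(2x))/dx
```

Returning trailing None values is the conventional way to mark non-differentiable arguments, so keeping backward's return tuple in step with forward's signature whenever an argument is added avoids this mismatch.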