state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

CUDA error when using Mamba2 with long context #503

Open titzehong opened 3 months ago

titzehong commented 3 months ago

Hi, I am benchmarking inference speed on long sequences and encountering CUDA-related errors specifically with the Mamba2 models at longer sequence lengths (>200k). This issue does not occur with Mamba1 models.

For example running:

python benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-1.3b" --promptlen 300000 --genlen 1

produces the error:

File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 260, in generate output = decode( File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, kwargs) File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 221, in decode scores.append(get_logits(sequences[-1], inference_params)) File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 184, in get_logits logits = model( File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, *kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(args, kwargs) File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 279, in forward hidden_states = self.backbone(input_ids, inference_params=inference_params, mixer_kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, *kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(args, kwargs) File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 194, in forward hidden_states, residual = layer( File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(args, kwargs) File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/block.py", line 67, in forward hidden_states = self.mixer(hidden_states, inference_params=inference_params, mixer_kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(args, kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, kwargs) File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/mamba2.py", line 242, in forward y = mamba_chunk_scan_combined( File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 581, in mamba_chunk_scan_combined return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit, return_final_states, return_varlen_states) File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 553, in apply return super().apply(*args, kwargs) # type: ignore[misc] File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 540, in forward out, out_x, dt_out, dA_cumsum, states, final_states, rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit) File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 312, in _mamba_chunk_scan_combined_fwd dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_chunk_state.py", line 675, in _chunk_cumsum_fwd _chunk_cumsum_fwd_kernel[grid_chunk_cs]( File 
"/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 326, in return lambda args, kwargs: self.run(grid=grid, warmup=False, *args, kwargs) File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in run timings = {config: self._bench(*args, config=config, *kwargs) for config in pruned_configs} File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in timings = {config: self._bench(args, config=config, **kwargs) for config in pruned_configs} File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 133, in _bench return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8)) File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 108, in do_bench torch.cuda.synchronize() File "/usr/local/lib/python3.10/dist-packages/torch/cuda/init.py", line 801, in synchronize return torch._C._cuda_synchronize() RuntimeError: CUDA error: an illegal memory access was encountered

This issue seems to occur only with Mamba2 models and is present across all model sizes. Mamba1, however, works well, and I am able to run inference on prompt lengths of up to 1M tokens with the 1.4b model.

I am using a single H100 (80GB) card.

Thanks!

tridao commented 3 months ago

Are any tensors of size >= 2GB? We use int32 for indexing, so it's possible that an index wraps around the int32 maximum and produces a negative value, causing the illegal memory access (IMA).
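
As a back-of-the-envelope illustration (the shape below is made up for the example, not taken from the model), an offset past 2**31 - 1 wraps to a negative number when it is held in 32 bits:

# Illustrative numbers only: an fp16 activation of shape (1, 300_000, 4096)
# spans ~2.4 GB, so the byte offset of its last element does not fit in a
# signed 32-bit integer.
numel = 300_000 * 4096               # ~1.23e9 elements
last_byte_offset = numel * 2 - 2     # fp16 = 2 bytes per element
print(last_byte_offset)              # 2457599998, which is > 2**31 - 1

def wrap_int32(x: int) -> int:
    # Reproduce two's-complement int32 wrap-around in plain Python.
    return (x + 2**31) % 2**32 - 2**31

print(wrap_int32(last_byte_offset))  # -1837367298: the "index" the kernel would use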

titzehong commented 3 months ago

Yes, I have several intermediate output tensors exceeding 2GB in size. However, this is also true at shorter context lengths that do not produce the error: at a context length of 250k, for example, almost all layer outputs exceed 2GB, yet it runs fine.

Noted on the indexing causing the issue. Would lowering the model's dimension be one possible workaround?

Hprairie commented 3 months ago

I think I found the problem and have submitted a PR.

iofu728 commented 1 month ago

+1.

Hprairie commented 1 month ago

Try the fix I made in the PR; it's a 4-line code change. Let me know if that works?

serendipityCoding commented 1 month ago

+1. I tried the fix in the PR, but it did not work for me either.

LuJunru commented 1 month ago

I found that this is related to a Triton issue. Modifying tl.program_id(*) to tl.program_id(*).to(tl.int64) avoids this error: https://github.com/triton-lang/triton/issues/1058.
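
For anyone patching locally, this is the general pattern (a toy Triton kernel for illustration, not the actual mamba_ssm kernel; the real change goes in the ssd_* kernels named in the traceback):

import triton
import triton.language as tl

@triton.jit
def copy_rows_kernel(x_ptr, out_ptr, stride_row, n_cols, BLOCK: tl.constexpr):
    # Cast the program id to int64 before it enters the pointer arithmetic;
    # otherwise pid * stride_row is computed in int32 and can wrap negative
    # once the tensor spans more than 2**31 - 1 bytes/elements.
    pid = tl.program_id(0).to(tl.int64)   # was: pid = tl.program_id(0)
    offs = pid * stride_row + tl.arange(0, BLOCK)
    mask = tl.arange(0, BLOCK) < n_cols
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)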