titzehong opened this issue 3 months ago
Are any tensors 2GB or larger? We use int32 for indexing, so an index can exceed the int32 maximum, wrap around to a negative value, and cause an illegal memory access (IMA).
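A minimal pure-Python sketch of that wrap-around, emulating signed 32-bit arithmetic. The helper names and the 4096-element row stride are hypothetical and are not taken from the Mamba2 kernels.

```python
INT32_MAX = 2**31 - 1


def to_int32(x: int) -> int:
    """Wrap an arbitrary Python int to a signed 32-bit value (two's complement)."""
    x &= 0xFFFFFFFF
    return x - 2**32 if x > INT32_MAX else x


def flat_offset_int32(row: int, row_stride: int) -> int:
    """Element offset row * row_stride, as computed with int32 arithmetic."""
    return to_int32(row * row_stride)


# Hypothetical row stride of 4096 elements (illustrative only).
ROW_STRIDE = 4096
print(flat_offset_int32(524_287, ROW_STRIDE))  # 2147479552: still fits in int32
print(flat_offset_int32(524_288, ROW_STRIDE))  # -2147483648: wrapped to a negative index
```

A negative offset like the second one points before the start of the buffer, which can surface as an illegal memory access on the GPU.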
Yes, I have several intermediate output tensors larger than 2GB. However, this is also true at shorter context lengths that do not produce the error. For example, at a context length of 250k, which runs without problems, almost all layer outputs exceed 2GB.
Noted on the indexing. Would lowering the model's dimension be one way to avoid the issue?
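A rough back-of-the-envelope sketch, assuming a kernel forms a flat element offset token_index * per_token_elems in int32 (a simplification; the real Mamba2 kernels combine several strides). Under that assumption the longest safe sequence scales inversely with the per-token width, so halving the dimension roughly doubles the limit. The limit is also counted in elements rather than bytes, which may be why some tensors above 2GB still work. The function names and widths are hypothetical.

```python
INT32_MAX = 2**31 - 1


def max_safe_seqlen(per_token_elems: int) -> int:
    """Largest seqlen for which every offset t * per_token_elems + j
    (0 <= t < seqlen, 0 <= j < per_token_elems) fits in a signed int32."""
    # The largest offset is seqlen * per_token_elems - 1, which must be <= INT32_MAX.
    return (INT32_MAX + 1) // per_token_elems


def tensor_bytes(seqlen: int, per_token_elems: int, elem_size: int) -> int:
    """Size in bytes of a contiguous (seqlen, per_token_elems) tensor."""
    return seqlen * per_token_elems * elem_size


# Hypothetical per-token widths, not taken from any Mamba2 config.
for width in (8192, 4096):
    print(width, max_safe_seqlen(width))  # prints 8192 262144, then 4096 524288

# A 3 GiB fp16 (2-byte) tensor can still have fewer than 2**31 elements.
seqlen, width = 196_608, 8192
print(tensor_bytes(seqlen, width, 2) / 2**30)  # 3.0 (GiB)
print(seqlen * width <= INT32_MAX)  # True
```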
I think I found the problem, I have submitted a PR.
+1.
Try the fix in my PR; it's a 4-line change. Let me know if that works.
+1. I tried the fix in the PR, and it doesn't work for me either.
I found this is related to a Triton issue (https://github.com/triton-lang/triton/issues/1058). Changing tl.program_id(*) to tl.program_id(*).to(tl.int64) avoids this error.
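A minimal Triton sketch of the cast suggested above, applied to a hypothetical row-wise kernel. The names scale_rows_kernel and scale_rows are illustrative only; this is not the actual mamba_ssm _chunk_cumsum_fwd_kernel, and it assumes torch and triton are installed (the launch itself needs a CUDA GPU).

```python
import torch
import triton
import triton.language as tl


@triton.jit
def scale_rows_kernel(x_ptr, out_ptr, n_cols, stride_row, scale, BLOCK_N: tl.constexpr):
    # tl.program_id returns an int32 value. Casting it to int64 *before*
    # multiplying by the row stride keeps row * stride_row from wrapping
    # around once the flat offset exceeds 2**31 - 1 elements.
    row = tl.program_id(0).to(tl.int64)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < n_cols
    offsets = row * stride_row + cols  # int64 offsets
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * scale, mask=mask)


def scale_rows(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Multiply every element of a 2-D CUDA tensor by `scale`, one program per row."""
    assert x.is_cuda and x.dim() == 2
    x = x.contiguous()
    out = torch.empty_like(x)
    n_rows, n_cols = x.shape
    BLOCK_N = triton.next_power_of_2(n_cols)
    scale_rows_kernel[(n_rows,)](x, out, n_cols, x.stride(0), scale, BLOCK_N=BLOCK_N)
    return out


if torch.cuda.is_available():
    x = torch.randn(4, 100, device="cuda")
    assert torch.allclose(scale_rows(x, 2.0), x * 2.0)
```

The cast costs little on a single scalar per program, which is why it is a common workaround when offsets can exceed the int32 range.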
Hi, I am benchmarking inference speed on long sequences and am encountering CUDA errors with Mamba2 models at longer sequence lengths (>200k). The issue does not occur with Mamba1 models.
For example, running:
produces the error:
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 260, in generate
    output = decode(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 221, in decode
    scores.append(get_logits(sequences[-1], inference_params))
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 184, in get_logits
    logits = model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 279, in forward
    hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 194, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/block.py", line 67, in forward
    hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/mamba2.py", line 242, in forward
    y = mamba_chunk_scan_combined(
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 581, in mamba_chunk_scan_combined
    return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit, return_final_states, return_varlen_states)
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 540, in forward
    out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 312, in _mamba_chunk_scan_combined_fwd
    dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_chunk_state.py", line 675, in _chunk_cumsum_fwd
    _chunk_cumsum_fwd_kernel[grid_chunk_cs](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 326, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 133, in _bench
    return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
  File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 108, in do_bench
    torch.cuda.synchronize()
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 801, in synchronize
    return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered
This issue seems to occur only with Mamba2 models and is present across all model sizes. Mamba1, however, works well: I can run inference on prompts of up to 1M tokens with the 1.4B model.
I am using a single H100 (80GB) GPU.
Thanks!