state-spaces / mamba

Mamba SSM architecture

Triton Error #370

Open Kevin-naticl opened 3 months ago

Kevin-naticl commented 3 months ago

Here is my error log. When I run it on one GPU it works fine, but when I start training on multiple GPUs it shows the errors below:

```
TypeError: Caught TypeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/work_nfs12/bykang/workspace/stcm_mss_new/model/HGCN_dprnn_mamba_v5.py", line 624, in forward
    out = self.full_mamba(feature_map)
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/work_nfs12/bykang/workspace/stcm_mss_new/model/HGCN_dprnn_mamba_v5.py", line 393, in forward
    out = self.enhance(out)
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/work_nfs12/bykang/workspace/stcm_mss_new/model/HGCN_dprnn_mamba_v5.py", line 272, in forward
    x = self.intra_RNN(x)
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/work_nfs12/bykang/workspace/stcm_mss_new/model/HGCN_dprnn_mamba_v5.py", line 148, in forward
    forward_f, for_residual = block(forward_f, for_residual, inference_params=None)
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/mamba_ssm/modules/block.py", line 53, in forward
    hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/mamba_ssm/ops/triton/layer_norm.py", line 944, in forward
    return rms_norm_fn(
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/mamba_ssm/ops/triton/layer_norm.py", line 908, in rms_norm_fn
    return LayerNormFn.apply(
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/mamba_ssm/ops/triton/layer_norm.py", line 748, in forward
    y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/mamba_ssm/ops/triton/layer_norm.py", line 335, in _layer_norm_fwd
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/triton/runtime/jit.py", line 167, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 153, in run
    full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
TypeError: 'NoneType' object is not a mapping
```
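
For reference, the final `TypeError` in that trace is just what dict unpacking does when it is handed `None`, which suggests `self.nargs` (or one of the merged dicts) is `None` in that replica. A tiny standalone illustration of the failure mode, not the autotuner code itself:

```python
# Minimal illustration: unpacking None with ** raises exactly the TypeError in the log.
nargs = None                # stands in for self.nargs in the failing replica
extra = {"BLOCK_N": 1024}   # stands in for the autotune config kwargs (illustrative)

try:
    full_nargs = {**nargs, **extra}
except TypeError as e:
    print(e)                # -> 'NoneType' object is not a mapping
```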

Kevin-naticl commented 3 months ago

My torch is 2.0.1 and my triton is 2.3.1, with CUDA 12.4, on a 4090 GPU.
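
In case it helps anyone compare environments, a quick sketch for printing the relevant versions (just the standard version attributes, nothing specific to this repo):

```python
# Print the torch / triton / CUDA versions and the GPU name, for comparing setups.
import torch
import triton

print("torch :", torch.__version__)
print("triton:", triton.__version__)
print("cuda  :", torch.version.cuda)
print("gpu   :", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "none")
```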

tridao commented 3 months ago

Yeah, looks like a Triton error. Idk what's wrong, I'm not an expert in Triton.

Kevin-naticl commented 3 months ago

> Yeah, looks like a Triton error. Idk what's wrong, I'm not an expert in Triton.

Could it be an incompatibility between the triton and torch versions? I saw similar errors last year with Mamba1.

Kevin-naticl commented 3 months ago

This issue: https://github.com/state-spaces/mamba/issues/84

Kevin-naticl commented 3 months ago

Using torch 2.1 and triton 2.1 seems to solve the problem.
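
For anyone trying the same combination, a small startup guard (the version strings are just the ones mentioned in this thread) to fail fast if the environment drifts:

```python
# Fail fast if the installed versions drift from the combination reported to work here.
import torch
import triton

for name, ver in (("torch", torch.__version__), ("triton", triton.__version__)):
    assert ver.startswith("2.1"), f"{name} {ver}: this thread used {name} 2.1.x"
```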

Kevin-naticl commented 3 months ago

Damn, I thought it was solved, but that was my data loader's problem.

Kevin-naticl commented 3 months ago

Temporarily solved the problem with the workaround below, mentioned in https://github.com/state-spaces/mamba/issues/84:

Patch it by editing python/triton/runtime/autotuner.py at line 75, replacing

```python
full_nargs = {**self.nargs, **current}
```

with

```python
full_nargs = {}
if self.nargs:
    full_nargs.update(self.nargs)
if current:
    full_nargs.update(current)
```
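
An equivalent, more compact form of the same guard, if you prefer a one-liner (same behavior, just `or {}` instead of the explicit `if`s):

```python
# Equivalent one-line guard: treat a missing mapping as an empty dict.
full_nargs = {**(self.nargs or {}), **(current or {})}
```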

But there is still a new problem:

```
    loss.backward()
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/torch/cuda/amp/autocast_mode.py", line 140, in decorate_bwd
    return bwd(*args, **kwargs)
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 871, in backward
    dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 403, in _mamba_chunk_scan_combined_bwd
    dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx)
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 245, in _chunk_scan_chunk_state_bwd_dx
    _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
  File "/home/environment2/bykang/anaconda3/envs/mamba/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 116, in run
    config = self.configs[0]
TypeError: 'NoneType' object is not a mapping
```
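
A minimal sketch that exercises the same `ssd_combined` backward under `nn.DataParallel` (which is what the first traceback points at, "replica 1 on device 1"). The sizes and the bare `Mamba2` block are illustrative, not the model from this thread, but it may help narrow down whether the multi-GPU replication is what breaks the autotuner:

```python
# Hypothetical minimal check: forward + backward through a bare Mamba2 block
# wrapped in nn.DataParallel. All sizes below are illustrative.
import torch
from torch import nn
from mamba_ssm import Mamba2

model = nn.DataParallel(Mamba2(d_model=256, d_state=64, d_conv=4, expand=2).cuda())
x = torch.randn(4, 1024, 256, device="cuda", requires_grad=True)  # (batch, seqlen, d_model)

y = model(x)
y.sum().backward()
print("forward/backward completed without a Triton autotuner error")
```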

Kevin-naticl commented 3 months ago

Right now the triton version is 2.1.0 and torch is 2.1.0.

DanFosing commented 4 weeks ago

Did anyone manage to fix this issue? If so, what did you do?