state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Backward error in triton #408

Open mahao18cm opened 2 months ago

mahao18cm commented 2 months ago

When I train Mamba2, I first got an error like this:

```
FileNotFoundError: [Errno 2] No such file or directory: '/root/.triton/cache/2809997ca5d0d9cfa31a77d0e143bb8b/_bmm_chunk_fwd_kernel.cubin.tmp.pid_10278_13620'
```

I worked around that by wrapping the cache write in `try: os.replace(temp_path, filepath)` / `except: pass`. However, when I trained again, a different bug appeared:

```
Traceback (most recent call last):
  File "/root/FlowFormer-Official/train_FlowFormer.py", line 176, in <module>
    train(cfg)
  File "/root/FlowFormer-Official/train_FlowFormer.py", line 98, in train
    scaler.scale(loss).backward()
  File "/root/miniconda3/envs/flownet2/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
    torch.autograd.backward(
  File "/root/miniconda3/envs/flownet2/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/root/miniconda3/envs/flownet2/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/root/miniconda3/envs/flownet2/lib/python3.10/site-packages/torch/autograd/function.py", line 301, in apply
    return user_fn(self, *args)
  File "/root/miniconda3/envs/flownet2/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 142, in decorate_bwd
    return bwd(*args, **kwargs)
  File "/root/miniconda3/envs/flownet2/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 871, in backward
    dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(
  File "/root/miniconda3/envs/flownet2/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 403, in _mamba_chunk_scan_combined_bwd
    dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx)
  File "/root/miniconda3/envs/flownet2/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 245, in _chunk_scan_chunk_state_bwd_dx
    _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
  File "/root/miniconda3/envs/flownet2/lib/python3.10/site-packages/triton/runtime/jit.py", line 167, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/root/miniconda3/envs/flownet2/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 168, in run
    config.pre_hook(full_nargs)
  File "/root/miniconda3/envs/flownet2/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 46, in <lambda>
    return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
  File "/root/miniconda3/envs/flownet2/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 46, in <listcomp>
    return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
KeyError: 'ddt_ptr'
```

But `ddt_ptr` is defined in the `_chunk_scan_chunk_state_bwd_dx_kernel` signature (shown below). Can anyone help solve this problem?
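For reference, the cache workaround mentioned above is essentially the following pattern, placed around the `os.replace(temp_path, filepath)` call inside Triton's cache manager (a rough sketch; `safe_replace` is just an illustrative wrapper, and the exact file and function names differ across Triton versions):

```python
import os

def safe_replace(temp_path: str, filepath: str) -> None:
    """Move a freshly written temp cache file into place, tolerating the race
    where another training process has already produced (or cleaned up) the
    same cache entry."""
    try:
        os.replace(temp_path, filepath)
    except FileNotFoundError:
        # temp_path disappeared because a concurrent process handled this
        # cache entry first; the final cached file should already exist.
        pass
```

The edit I actually made used a bare `except: pass`; catching only `FileNotFoundError` is a slightly tighter version of the same idea. Here is the kernel signature showing that `ddt_ptr` is a declared parameter: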

```python
@triton.jit
def _chunk_scan_chunk_state_bwd_dx_kernel(
    # Pointers to matrices
    x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr,
    b_ptr, dstates_ptr,
    dx_ptr, ddt_ptr, dD_ptr,
    # Matrix dimensions
    chunk_size, hdim, dstate,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_D_head,
    stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
    stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate,
    stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
    stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
    stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,
    # Meta-parameters
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
    IS_TRITON_22: tl.constexpr,
):
    ...
```
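For additional context, the `pre_hook` that raises the `KeyError` (line 46 of `ssd_combined.py` in the traceback) looks like the `init_to_zero` helper that the autotune configs attach so `ddt_ptr` is zeroed before each run. A sketch reconstructed from the traceback (the config values below are illustrative, not the library's exact ones):

```python
import triton

def init_to_zero(names):
    # Autotuner pre_hook: zero the named kernel arguments (tensors) before each
    # run, so results accumulated by a previous config do not leak into the next.
    return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]

# The autotune configs for the bwd_dx kernel attach the hook roughly like this:
configs = [
    triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32},
                  num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
]
```

So the `KeyError: 'ddt_ptr'` means the argument dict (`full_nargs`) that the installed Triton's autotuner hands to this hook has no `'ddt_ptr'` entry, even though the kernel signature above declares it; that points at a mismatch between the Triton version's autotuner and what mamba_ssm expects rather than at the kernel definition itself.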

hanmiaolingit commented 1 month ago

Same problem here with multi-GPU training; with a single GPU it does not occur.

Peilin-FF commented 1 month ago

Me too, this problem occurred when I used multiple GPUs to train my model.