Audio-WestlakeU / NBSS

The official repo of NBC & SpatialNet for multichannel speech separation, denoising, and dereverberation
MIT License

Asking for Help: OnlineSpatialNet Mamba Version Can't Work #26

Closed zx1292982431 closed 7 months ago

zx1292982431 commented 7 months ago

Awesome job! I ran into a problem while trying to reproduce the OnlineSpatialNet Mamba version, and I hope to get your help. When I set inference=False, the model forwards normally. But when I set inference=True, it fails. Here is the traceback:

  Traceback (most recent call last):
    File "/mnt/raid2/user_space/lizixuan/projects/SpatialNet_Casual/models/arch/OnlineSpatialNet.py", line 418, in <module>
      res = model(x, inference=True).mean()
    File "/home/lizixuan/miniconda3/envs/SpatialNet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/home/lizixuan/miniconda3/envs/SpatialNet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
      result = forward_call(*args, **kwargs)
    File "/mnt/raid2/user_space/lizixuan/projects/SpatialNet_Casual/models/arch/OnlineSpatialNet.py", line 349, in forward
      x, attn = m(x, mask, chunkwise_recurrent, self.rope, None, inference)
    File "/home/lizixuan/miniconda3/envs/SpatialNet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/home/lizixuan/miniconda3/envs/SpatialNet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
      result = forward_call(*args, **kwargs)
    File "/mnt/raid2/user_space/lizixuan/projects/SpatialNet_Casual/models/arch/OnlineSpatialNet.py", line 160, in forward
      x = x + self._mamba(x, self.mhsa, self.norm_mhsa, self.dropout_mhsa, inference)
    File "/mnt/raid2/user_space/lizixuan/projects/SpatialNet_Casual/models/arch/OnlineSpatialNet.py", line 179, in _mamba
      xi = mamba.forward(x[:, [i], :], inference_params)
    File "/home/lizixuan/miniconda3/envs/SpatialNet/lib/python3.10/site-packages/mamba_ssm/modules/mamba_simple.py", line 131, in forward
      out, _, _ = self.step(hidden_states, conv_state, ssm_state)
    File "/home/lizixuan/miniconda3/envs/SpatialNet/lib/python3.10/site-packages/mamba_ssm/modules/mamba_simple.py", line 248, in step
      y = selective_state_update(
    File "/home/lizixuan/miniconda3/envs/SpatialNet/lib/python3.10/site-packages/mamba_ssm/ops/triton/selective_state_update.py", line 137, in selective_state_update
      _selective_scan_update_kernel[grid](
    File "/home/lizixuan/miniconda3/envs/SpatialNet/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 305, in run
      return self.fn.run(*args, **kwargs)
    File "/home/lizixuan/miniconda3/envs/SpatialNet/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 305, in run
      return self.fn.run(*args, **kwargs)
    File "/home/lizixuan/miniconda3/envs/SpatialNet/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 305, in run
      return self.fn.run(*args, **kwargs)
    [Previous line repeated 1 more time]
    File "/home/lizixuan/miniconda3/envs/SpatialNet/lib/python3.10/site-packages/triton/runtime/jit.py", line 550, in run
      bin.c_wrapper(
    File "/home/lizixuan/miniconda3/envs/SpatialNet/lib/python3.10/site-packages/triton/compiler/compiler.py", line 692, in __getattribute__
      self._init_handles()
    File "/home/lizixuan/miniconda3/envs/SpatialNet/lib/python3.10/site-packages/triton/compiler/compiler.py", line 683, in _init_handles
      mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
  RuntimeError: Triton Error [CUDA]: device kernel image is invalid

Additionally, I ran this on a single V100 (32 GB) GPU; my environment configuration is below, followed by a quick version-check snippet:

python == 3.10.14
torch == 2.2.2+cu118
causal-conv1d == 1.2.0.post2
mamba-sim == 1.2.0.post1
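
For reference, a quick sanity check along these lines (a hypothetical snippet, not code from the repo) can confirm whether the installed Triton/CUDA stack matches the GPU, since "device kernel image is invalid" usually means the kernels were built for a different compute capability or CUDA version than the one in use:

  # Hypothetical environment check (not from the NBSS repo): print the versions
  # that matter for the Triton/Mamba CUDA kernels. A V100 has compute capability (7, 0).
  import importlib.metadata
  import torch
  print("torch:", torch.__version__)
  print("CUDA (torch build):", torch.version.cuda)
  print("GPU:", torch.cuda.get_device_name(0))
  print("compute capability:", torch.cuda.get_device_capability(0))
  for pkg in ("triton", "mamba-ssm", "causal-conv1d"):
      try:
          print(pkg, importlib.metadata.version(pkg))
      except importlib.metadata.PackageNotFoundError:
          print(pkg, "not installed")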

My WeChat ID is zx1292982431, in case that makes communication more convenient.

zx1292982431 commented 7 months ago

Additions:

  1. I didn't use torch.compile.
  2. mamba-sim above should read mamba-ssm == 1.2.0.post1.

quancs commented 7 months ago

Thank you for your interest in our work. Actually, inference=True is only used for measuring FLOPs in our experiments, and inference=True and inference=False should theoretically produce the same result. So it is fine to set inference=False for training and evaluation.
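
For intuition, here is a minimal, self-contained sketch (a toy example, not the repo's code) of why a parallel pass and a step-by-step pass over a linear recurrence agree up to numerical precision; the same reasoning is why the two inference settings should match:

  # Toy illustration (not NBSS code): both branches compute h[t] = a*h[t-1] + b*x[t].
  import torch
  torch.manual_seed(0)
  T, a, b = 16, 0.9, 0.5
  x = torch.randn(T)
  # Parallel ("offline") form: h[t] = sum_{j<=t} a^(t-j) * b * x[j]
  k = torch.arange(T, dtype=torch.float32)
  decay = a ** (k.unsqueeze(1) - k.unsqueeze(0))   # decay[t, j] = a^(t-j)
  causal = torch.tril(torch.ones(T, T))            # keep only j <= t
  h_parallel = (decay * causal) @ (b * x)
  # Step-by-step ("online") form, as a streaming model would run at inference time.
  h, h_step = torch.zeros(()), []
  for t in range(T):
      h = a * h + b * x[t]
      h_step.append(h)
  print(torch.allclose(h_parallel, torch.stack(h_step), atol=1e-5))  # True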

zx1292982431 commented 7 months ago

> Thank you for your interest in our work. Actually, inference=True is only used for measuring FLOPs in our experiments, and inference=True and inference=False should theoretically produce the same result. So it is fine to set inference=False for training and evaluation.

Thank you!