state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

_layer_norm_fwd_1pass_kernel error #84

Open chenwuchen opened 8 months ago

chenwuchen commented 8 months ago

Title: Error when running multi-GPU training with Mamba

Description: I am experiencing an issue when running multi-GPU training with Mamba. Specifically, I am getting a TypeError: 'NoneType' object is not a mapping error when running the forward pass of the model. The error occurs when I try to run the model on multiple GPUs using the DataParallel module. However, when I run the model on a single GPU, everything works fine.

I have tried to reproduce the issue with a minimal example, but I was unable to do so. I have also checked the documentation and searched online for similar issues, but I couldn't find anything useful.

Here is the full traceback of the error:

Traceback (most recent call last):
  File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 152, in forward
    hidden_states, residual = layer(
  File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/mamba_ssm/modules/mamba_simple.py", line 341, in forward
    hidden_states, residual = fused_add_norm_fn(
  File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 75, in _bench
    full_nargs = {**self.nargs, **current}
TypeError: 'NoneType' object is not a mapping

I am using Python 3.10, PyTorch 1.12.1, causal_conv1d-1.1.1, mamba-ssm-1.1.1, and triton-2.1.0.
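To report the same environment details reliably, here is a small sketch that reads installed versions from the active environment using the standard library's `importlib.metadata`. The distribution names in `PACKAGES` are assumed PyPI names, and the helper `installed_versions` is hypothetical. Any package that is not installed is reported as `None`.

```python
from importlib import metadata

# Assumed PyPI distribution names for the packages discussed in this issue.
PACKAGES = ("torch", "triton", "mamba-ssm", "causal-conv1d")


def installed_versions(packages=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```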

akkikiki commented 8 months ago

@chenwuchen Bumped into the same error. Solution: Use DDP or torchrun.

mhamzaerol commented 7 months ago

I was facing the same issue. Looking into it further, I found that line 75 of autotuner.py in the triton package receives self.nargs=None, so building a dictionary from it fails:

(inside the _bench method)

full_nargs = {**self.nargs, **current}

My take is that this may:

However, I was able to run DataParallel after replacing line 75 of autotuner.py with this:

full_nargs = {}
if self.nargs:
    full_nargs.update(self.nargs)
if current:
    full_nargs.update(current)

That said, I'm not sure whether training behaves as expected overall.
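The following is a minimal pure-Python reproduction of why line 75 fails and what the workaround changes. Here `self.nargs` and `current` are plain function arguments, and the names `original_merge` and `patched_merge` are hypothetical. The sketch only shows that the patched merge no longer raises on `None`. It says nothing about whether the autotuner then benchmarks with the correct arguments.

```python
def original_merge(nargs, current):
    """Same expression as line 75 of triton/runtime/autotuner.py in triton 2.1.0."""
    # Dict unpacking needs a mapping on both sides; None raises TypeError.
    return {**nargs, **current}


def patched_merge(nargs, current):
    """The workaround from this comment: skip whichever side is None or empty."""
    full_nargs = {}
    if nargs:
        full_nargs.update(nargs)
    if current:
        full_nargs.update(current)
    return full_nargs


if __name__ == "__main__":
    current = {"BLOCK_N": 1024}  # hypothetical tuned meta-parameter
    try:
        original_merge(None, current)
    except TypeError as exc:
        print(f"original: {exc}")  # original: 'NoneType' object is not a mapping
    print("patched:", patched_merge(None, current))  # patched: {'BLOCK_N': 1024}
```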

PheelaV commented 6 months ago

I needed to do the same. I am running the trainer script from mamba-chat and got the same Triton error as above. Here is how I patched and installed the package:

Get the package

git clone -b release/2.1.x https://github.com/openai/triton.git
pip install cmake

Patch it by editing python/triton/runtime/autotuner.py at line 75

replacing

full_nargs = {**self.nargs, **current}

with

full_nargs = {}
if self.nargs:
    full_nargs.update(self.nargs)
if current:
    full_nargs.update(current)

Then install the patched version:

cd triton/python
pip install -e .

Install the rest of the Mamba dependencies as usual.

(I am using Python 3.11 in a conda environment and currently have training running on a pair of RTX 3090s.)

s22chan commented 5 months ago

Hacking it that way could cause silent errors, especially if different arguments are passed concurrently to the JIT.

It looks like DataParallel is multi-threaded, and Triton doesn't appear to be thread-safe.

If you're confident the rest of the forward pass is thread-safe and you must run in threaded mode, you could try running a single pass first to bootstrap the JIT before running in parallel. I don't think Mamba uses a config.pre_hook, which is the only thing that would be written per run once benchmarking is complete.
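The failure mode and the warm-up suggestion can be illustrated with a toy model. `ToyAutotuner`, `simulate_interleaving`, and `run_threaded` are hypothetical and are not Triton's code. The model assumes, as a simplification, that the autotuner stores the current call's arguments in a shared attribute (`nargs`), reads that attribute while benchmarking configs on a cache miss, and resets it to `None` when the call ends. Under that assumption, one thread finishing its call can clear `nargs` while another thread is still benchmarking. A single-threaded warm-up call for each cache key means later threaded calls never benchmark at all.

```python
import threading


class ToyAutotuner:
    """Hypothetical autotuner that keeps per-call state on the shared
    instance (a simplification, not Triton's implementation)."""

    def __init__(self, configs):
        self.configs = configs
        self.cache = {}      # key -> best config, filled on the first call per key
        self.nargs = None    # shared per-call state: unsafe under threads
        self.bench_calls = 0

    def _bench(self, config):
        self.bench_calls += 1
        # Same unpacking pattern as autotuner.py line 75: fails if nargs was reset.
        full_nargs = {**self.nargs, "config": config}
        return len(full_nargs) * config  # fake "timing"

    def run(self, **kwargs):
        self.nargs = dict(kwargs)
        key = kwargs["n"]
        if key not in self.cache:  # only a cache miss triggers benchmarking
            timings = {c: self._bench(c) for c in self.configs}
            self.cache[key] = min(timings, key=timings.get)
        best = self.cache[key]
        self.nargs = None    # reset at the end of every call
        return best


def simulate_interleaving():
    """Replays, step by step, one bad interleaving of two threads."""
    tuner = ToyAutotuner(configs=[1, 2])
    tuner.nargs = {"n": 4}   # thread A enters run() and stores its arguments
    tuner.nargs = None       # thread B completes its whole call, resetting nargs
    tuner._bench(1)          # thread A now benchmarks: raises TypeError


def run_threaded(tuner, n, num_threads=8, warm_up=True):
    """Calls tuner.run from several threads, optionally after one warm-up call."""
    if warm_up:
        tuner.run(n=n)       # single-threaded call fills the cache for key n
    errors = []

    def worker():
        try:
            tuner.run(n=n)
        except TypeError as exc:
            errors.append(exc)

    threads = [threading.Thread(target=worker) for _ in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return errors


if __name__ == "__main__":
    try:
        simulate_interleaving()
    except TypeError as exc:
        print("race:", exc)
    tuner = ToyAutotuner(configs=[1, 2])
    print("errors after warm-up:", run_threaded(tuner, n=4))
    print("benchmark calls:", tuner.bench_calls)
```

In the real setting, the analogous warm-up would be one ordinary forward pass before the DataParallel forward. Whether that pass covers every autotuning key depends on the shapes and dtypes the parallel calls use, so treat this as a hint rather than a guarantee.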

AndssY commented 5 months ago

@PheelaV Can you provide more details about how to build Triton from source? I can't install Triton with Python 3.8.5 in a conda environment. Thanks!

.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/mlir/IR/Value.h:95:56: error: ‘void* __builtin_memset(void*, int, long unsigned int)’ specified size between 18446744039349813224 and 18446744073709551608 exceeds maximum object size 9223372036854775807 [-Werror=stringop-overflow=]
         95 |   constexpr Value(detail::ValueImpl *impl = nullptr) : impl(impl) {}
...
        270 |   iterator begin() { return (iterator)this->BeginX; }
            |                                       ~~~~~~^~~~~~
      At global scope:
      cc1plus: note: unrecognized command-line option ‘-Wno-covered-switch-default’ may have been intended to silence earlier diagnostics
      cc1plus: all warnings being treated as errors
      gmake[2]: *** [lib/Conversion/TritonGPUToLLVM/CMakeFiles/obj.TritonGPUToLLVM.dir/build.make:163: lib/Conversion/TritonGPUToLLVM/CMakeFiles/obj.TritonGPUToLLVM.dir/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp.o] Error 1
      gmake[2]: Leaving directory 'triton/python/build/cmake.linux-x86_64-cpython-3.8'
      gmake[1]: *** [CMakeFiles/Makefile2:2299: lib/Conversion/TritonGPUToLLVM/CMakeFiles/obj.TritonGPUToLLVM.dir/all] Error 2
      gmake[1]: Leaving directory 'triton/python/build/cmake.linux-x86_64-cpython-3.8'
      gmake: *** [Makefile:149: all] Error 2
...
ERROR: Could not build wheels for triton, which is required to install pyproject.toml-based projects

PheelaV commented 5 months ago

@AndssY Sorry, I can't really help there. I was using Python 3.10 or 3.11 on Ubuntu 22.04 LTS, and following their README instructions everything worked out. I think the only non-standard thing I had to do was check out the specific release branch required by the Mamba dependencies (important).

But this whole thing became redundant: I think Mamba was patched, and everything suddenly started to work with just the mamba-ssm and causal-conv1d installs. I kept the environment with Triton built from source just to be sure, but it was no longer necessary for me.

I hope that helps at least a little. Good luck getting it up and running. Feel free to DM me if you are still struggling; I think I went through every hoop there was.

AndssY commented 5 months ago

...and following their readme instructions everything worked out. I think the only thing I had to do out of standard was to follow the specific release branch as requested per mamba dependencies (important).

@PheelaV Did you install according to the mamba-chat README? So mamba-ssm==1.0.1 and triton from release/2.1.x?

I will try python==3.11 and install it again following the readme of mamba-chat. Thanks very much!