redotvideo / mamba-chat

Mamba-Chat: A chat LLM based on the state-space model architecture 🐍
Apache License 2.0
911 stars, 69 forks

Error during training #15

Closed: Eupham closed this issue 10 months ago

Eupham commented 11 months ago

I can get the model to run inference just fine, but when I try to train it in my Colab environment, this is what I run into:

```
2023-12-21 14:14:22.093031: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-21 14:14:22.093134: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-21 14:14:22.162482: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-21 14:14:24.060295: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Tokenizing dataset...
100% 1000/1000 [00:02<00:00, 383.23it/s]
  0% 0/750 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/content/mamba-chat/train_mamba.py", line 60, in <module>
    run(args)
  File "/content/mamba-chat/train_mamba.py", line 45, in run
    trainer.train()
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1555, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1860, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2725, in training_step
    loss = self.compute_loss(model, inputs)
  File "/content/mamba-chat/trainer/mamba_trainer.py", line 9, in compute_loss
    lm_logits = model(input_ids).logits
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 221, in forward
    hidden_states = self.backbone(input_ids, inference_params=inference_params)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 152, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/mamba_simple.py", line 341, in forward
    hidden_states, residual = fused_add_norm_fn(
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 100, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 100, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 83, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
  File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 104, in do_bench
    fn()
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 81, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 63, in _layer_norm_fwd_1pass_kernel
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 476, in compile
    next_module = compile_kernel(module)
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 351, in <lambda>
    lambda src: ptx_to_cubin(src, arch))
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 150, in ptx_to_cubin
    return compile_ptx_to_cubin(ptx, ptxas, arch)
RuntimeError: Internal Triton PTX codegen error:
ptxas /tmp/compile-ptx-src-eed3b0, line 984; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 984; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 986; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 986; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 988; error : Feature '.bf16' requires .target sm_80 or higher
....
ptxas /tmp/compile-ptx-src-eed3b0, line 2801; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2803; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2803; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2805; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2805; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2807; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2807; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas fatal : Ptx assembly aborted due to errors

  0% 0/750 [00:02<?, ?it/s]
```

Rohith04MVK commented 10 months ago

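I don't know if it's too late, but changing the dtype from bfloat16 to float16 here does seem to fix the problem on Colab. I think bfloat16 isn't supported on the T4: it is a compute capability 7.5 (sm_75) GPU, and the ptxas errors above say bf16 requires sm_80 or higher.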

Edit: Setting the dtype to float16 seems to have broken the loss, but after setting it to float32 everything worked fine.
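Rather than hard-coding a dtype, one option is to pick it from the GPU's compute capability. This is a minimal sketch: `choose_dtype_name` is a hypothetical helper, not part of mamba-chat; the sm_80 threshold comes from the ptxas errors in the original post, and the float32 fallback follows the edit above (float16 reportedly broke the loss). The `MambaLMHeadModel.from_pretrained(..., dtype=...)` line in the comment assumes that is how the training script loads the model.

```python
# Sketch: choose the model dtype from the GPU's compute capability.
# choose_dtype_name is a hypothetical helper, not part of mamba-chat.


def choose_dtype_name(capability):
    """Map a CUDA compute capability (major, minor) to a torch dtype name."""
    # The ptxas errors in this issue say bf16 instructions need sm_80 or higher.
    if tuple(capability) >= (8, 0):
        return "bfloat16"
    # Older GPUs such as the T4 (sm_75): float16 reportedly broke the loss
    # in this thread, while float32 trained fine.
    return "float32"


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and torch.cuda.is_available():
        # get_device_capability() returns a (major, minor) tuple, e.g. (7, 5) on a T4
        name = choose_dtype_name(torch.cuda.get_device_capability())
        dtype = getattr(torch, name)
        print(f"Loading the model in {dtype}")
        # Assumed loading call: MambaLMHeadModel.from_pretrained(model_name, dtype=dtype, device="cuda")
```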

Eupham commented 10 months ago

I'll give this a try later. Thanks.

Eupham commented 10 months ago

It works. After installing the requirements, I still have to run the following before it will run on Colab, but I'm pretty sure that's a separate problem:

```
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia
```
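If you'd rather apply these settings from Python, here is a minimal sketch, assuming training is launched as a subprocess from the notebook. In a notebook, each `!` line runs in its own subshell, so an `!export` on its own likely doesn't carry over to later cells; passing the variables explicitly through `env=` avoids relying on that. `colab_training_env` and `launch_training` are hypothetical helpers, and the values are copied verbatim from the commands in this comment.

```python
import os
import subprocess

# Values copied from the Colab workaround in this thread; the paths are
# specific to Colab's image.
COLAB_ENV_OVERRIDES = {
    "LC_ALL": "en_US.UTF-8",
    "LD_LIBRARY_PATH": "/usr/lib64-nvidia",
    "LIBRARY_PATH": "/usr/local/cuda/lib64/stubs",
}


def colab_training_env(base_env):
    """Return a copy of base_env with the Colab overrides applied (base_env is not modified)."""
    env = dict(base_env)
    env.update(COLAB_ENV_OVERRIDES)
    return env


def launch_training(extra_args=()):
    """Hypothetical launcher: refresh the linker cache, then run train_mamba.py.

    Add your usual train_mamba.py arguments via extra_args.
    """
    subprocess.run(["ldconfig", "/usr/lib64-nvidia"], check=True)
    cmd = ["python", "train_mamba.py", *extra_args]
    subprocess.run(cmd, env=colab_training_env(os.environ), check=True)
```

Calling `launch_training()` from a Colab cell would then replace the four `!` lines plus the usual training command.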

Also, adding `fp16=True` to the training arguments results in faster training on Colab.
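A minimal sketch of where `fp16=True` goes, assuming the training script builds a `transformers.TrainingArguments`. `with_fp16_amp` is a hypothetical helper and `output_dir` is a placeholder. `fp16=True` turns on mixed-precision training in the Hugging Face Trainer, so it can sit on top of a model loaded in float32 as suggested earlier in the thread.

```python
def with_fp16_amp(training_kwargs, cuda_available):
    """Return a copy of the TrainingArguments kwargs with fp16 mixed precision toggled.

    fp16 is only enabled when a CUDA GPU is present.
    """
    kwargs = dict(training_kwargs)
    kwargs["fp16"] = bool(cuda_available)
    return kwargs


if __name__ == "__main__":
    base = {"output_dir": "mamba-chat-out"}  # placeholder; keep your existing arguments
    print(with_fp16_amp(base, cuda_available=True))
    # With torch and transformers installed, something like:
    # TrainingArguments(**with_fp16_amp(base, torch.cuda.is_available()))
```

The CUDA guard matters because recent versions of transformers reject `fp16=True` when no supported GPU is available.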

Rohith04MVK commented 10 months ago

Yeah, that looks like a Colab issue?

Eupham commented 10 months ago

It is. Honestly, I got that solution from here; I haven't looked into why it works yet.

Rohith04MVK commented 10 months ago

I found an open issue over at PyTorch; it looks like a version incompatibility between Triton and PyTorch(?). I don't think the maintainers will be happy with us discussing that here.
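When a Triton/PyTorch version mismatch is suspected, a useful first step is to record exactly which versions are installed. A minimal sketch using only the standard library's `importlib.metadata`; `installed_versions` is a hypothetical helper, and the package list is an assumption (the distributions this traceback passes through).

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is missing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(["torch", "triton", "mamba-ssm", "transformers"]).items():
        print(f"{name}: {ver or 'not installed'}")
```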