redotvideo / mamba-chat

Mamba-Chat: A chat LLM based on the state-space model architecture 🐍
Apache License 2.0

Error during training #15

Closed: Eupham closed this issue 8 months ago

Eupham commented 9 months ago

I can get the model to perform inference just fine, but in my Colab environment this is what I run into:

```
2023-12-21 14:14:22.093031: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-21 14:14:22.093134: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-21 14:14:22.162482: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-21 14:14:24.060295: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Tokenizing dataset... 100% 1000/1000 [00:02<00:00, 383.23it/s]
  0% 0/750 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/content/mamba-chat/train_mamba.py", line 60, in <module>
    run(args)
  File "/content/mamba-chat/train_mamba.py", line 45, in run
    trainer.train()
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1555, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1860, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2725, in training_step
    loss = self.compute_loss(model, inputs)
  File "/content/mamba-chat/trainer/mamba_trainer.py", line 9, in compute_loss
    lm_logits = model(input_ids).logits
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 221, in forward
    hidden_states = self.backbone(input_ids, inference_params=inference_params)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 152, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/mamba_simple.py", line 341, in forward
    hidden_states, residual = fused_add_norm_fn(
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 100, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 100, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 83, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
  File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 104, in do_bench
    fn()
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 81, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 63, in _layer_norm_fwd_1pass_kernel
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 476, in compile
    next_module = compile_kernel(module)
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 351, in <lambda>
    lambda src: ptx_to_cubin(src, arch))
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 150, in ptx_to_cubin
    return compile_ptx_to_cubin(ptx, ptxas, arch)
RuntimeError: Internal Triton PTX codegen error:
ptxas /tmp/compile-ptx-src-eed3b0, line 984; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 984; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 986; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 986; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 988; error : Feature '.bf16' requires .target sm_80 or higher
....
ptxas /tmp/compile-ptx-src-eed3b0, line 2801; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2803; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2803; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2805; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2805; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2807; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2807; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas fatal : Ptx assembly aborted due to errors

  0% 0/750 [00:02<?, ?it/s]
```
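The `.bf16 ... requires .target sm_80 or higher` lines point at the GPU's compute capability; a quick way to check what Colab assigned (assuming a standard PyTorch install, nothing specific to this repo) is:

```python
import torch

# Ampere (sm_80) and newer GPUs support bfloat16; the Colab T4 reports (7, 5).
print(torch.cuda.get_device_name(0))
print(torch.cuda.get_device_capability(0))
print(torch.cuda.is_bf16_supported())
```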

Rohith04MVK commented 8 months ago

I don't know if it's too late, but changing the dtype from bfloat16 to plain float16 here does seem to fix the problem on Colab; bfloat16 isn't supported on the T4, I think.

Edit: Setting the dtype to float16 seems to have broken the loss, but after setting it to float32 everything worked fine.
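For anyone hitting the same thing, this is the kind of change being described; a minimal sketch of loading the model in float32 instead of bfloat16 (the model name and the exact load call in train_mamba.py are assumptions here, so adjust to match the script):

```python
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# bfloat16 kernels need an sm_80+ (Ampere or newer) GPU; the Colab T4 is sm_75.
# float16 compiles but broke the loss for us, so float32 is the safe choice.
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-2.8b",  # placeholder; use whatever --model you pass
    dtype=torch.float32,        # was torch.bfloat16
    device="cuda",
)
```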

Eupham commented 8 months ago

I'll give this a try later. Thanks

Eupham commented 8 months ago

It works. After installing the requirements I still have to add the following to get it running on Colab, but I'm pretty sure that's a separate problem:

`!export LC_ALL="en_US.UTF-8"`
`!export LD_LIBRARY_PATH="/usr/lib64-nvidia"`
`!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"`
`!ldconfig /usr/lib64-nvidia`

Also, under the training args, adding `fp16=True` results in faster training on Colab.
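A minimal sketch of where that flag goes (the other argument values here are placeholders; the real ones come from train_mamba.py's CLI args):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="mamba-chat-out",     # placeholder
    per_device_train_batch_size=4,   # placeholder
    num_train_epochs=1,              # placeholder
    fp16=True,  # mixed-precision training; noticeably faster on the Colab T4
)
```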

Rohith04MVK commented 8 months ago

> It works. After installing the requirements I still have to add the following to get it running on Colab, but I'm pretty sure that's a separate problem:
> `!export LC_ALL="en_US.UTF-8"`
> `!export LD_LIBRARY_PATH="/usr/lib64-nvidia"`
> `!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"`
> `!ldconfig /usr/lib64-nvidia`

Yeah, that looks like a Colab issue?

Eupham commented 8 months ago

It is. Honestly, I got that solution from here; I haven't looked up why it works yet.

Rohith04MVK commented 8 months ago

> It is. Honestly, I got that solution from here; I haven't looked up why it works yet.

Found an open issue over at PyTorch; it looks like a version issue between Triton and PyTorch(?). I don't think the maintainers will be happy with us discussing that here, though.