I don't know if it's too late, but changing the dtype from `bfloat16` to `float16` here does seem to fix the problem on Colab. bf16 isn't supported on the T4, I think.
Edit: Setting the dtype to `float16` seems to have broken the loss, but after setting it to `float32` everything worked fine.
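For reference, here's a minimal sketch of the dtype change, assuming the model is loaded through `mamba_ssm`'s `MambaLMHeadModel`; the checkpoint name is a placeholder and the exact call in `train_mamba.py` may differ:

```python
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Placeholder checkpoint; the actual load call in train_mamba.py may differ.
# T4s (sm_75) can't run bf16 Triton kernels, and float16 broke the loss here,
# so float32 is the safe choice.
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m",
    dtype=torch.float32,  # was torch.bfloat16
    device="cuda",
)
```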
I'll give this a try later. Thanks
It works. After installing the requirements, I still have to add the lines below to get it running on Colab, but I'm pretty sure that's a separate problem:

```
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia
```
Also, adding `fp16=True` under the training args results in faster training on Colab.
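A minimal sketch of where that flag goes, assuming the stock Hugging Face `TrainingArguments`; every other value here is a placeholder:

```python
from transformers import TrainingArguments

# Placeholder values; fp16=True is the only change under discussion.
training_args = TrainingArguments(
    output_dir="./out",
    per_device_train_batch_size=4,
    num_train_epochs=1,
    fp16=True,  # float16 mixed precision; supported on T4, unlike bf16
)
```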
Yeah, that looks like a Colab issue?
It is. Honestly, I got that solution from here; I haven't looked up why it works yet.
I can get the model to perform inference just fine, but in my Colab env this is what I run into:
```
2023-12-21 14:14:22.093031: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-21 14:14:22.093134: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-21 14:14:22.162482: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-21 14:14:24.060295: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Tokenizing dataset...
100% 1000/1000 [00:02<00:00, 383.23it/s]
  0% 0/750 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/content/mamba-chat/train_mamba.py", line 60, in <module>
    run(args)
  File "/content/mamba-chat/train_mamba.py", line 45, in run
    trainer.train()
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1555, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1860, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2725, in training_step
    loss = self.compute_loss(model, inputs)
  File "/content/mamba-chat/trainer/mamba_trainer.py", line 9, in compute_loss
    lm_logits = model(input_ids).logits
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 221, in forward
    hidden_states = self.backbone(input_ids, inference_params=inference_params)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 152, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/mamba_simple.py", line 341, in forward
    hidden_states, residual = fused_add_norm_fn(
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 100, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 100, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 83, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
  File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 104, in do_bench
    fn()
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 81, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 63, in _layer_norm_fwd_1pass_kernel
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 476, in compile
    next_module = compile_kernel(module)
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 351, in <lambda>
    lambda src: ptx_to_cubin(src, arch))
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 150, in ptx_to_cubin
    return compile_ptx_to_cubin(ptx, ptxas, arch)
RuntimeError: Internal Triton PTX codegen error:
ptxas /tmp/compile-ptx-src-eed3b0, line 984; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 984; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 986; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 986; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 988; error : Feature '.bf16' requires .target sm_80 or higher
...
ptxas /tmp/compile-ptx-src-eed3b0, line 2801; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2803; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2803; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2805; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2805; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2807; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2807; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas fatal : Ptx assembly aborted due to errors
  0% 0/750 [00:02<?, ?it/s]
```
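For anyone hitting this later: the ptxas errors say the Triton kernel uses bf16 instructions, which need compute capability 8.0 (sm_80, i.e. Ampere or newer), while Colab's T4 is sm_75. A quick way to check what your runtime supports, using standard PyTorch calls (nothing repo-specific):

```python
import torch

# A T4 reports (7, 5); bf16 Triton kernels require at least (8, 0).
major, minor = torch.cuda.get_device_capability()
print(f"compute capability: sm_{major}{minor}")
print("bf16 supported:", torch.cuda.is_bf16_supported())
```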