huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

gemma2 + flash atten Error: RuntimeError: linalg.vector_norm: Expected a floating point or complex tensor as input. Got Long #32103

Open kkk935208447 opened 1 month ago

kkk935208447 commented 1 month ago

System Info

transformers: 4.43.0.dev0, torch: 2.3.0, deepspeed: 0.14.0, flash-attn: 2.6.1

Since the latest flash attention release supports logit soft-capping, I added RoPE position-interpolation (PI) length extension to gemma2, following llama_model.py. I then fine-tuned gemma2 with the latest transformers (4.43.0.dev0) and flash attention 2.6.1. I had previously fine-tuned gemma2 with transformers 4.42.3 (which used eager attention), and the inference results were good, so I expected this run to go smoothly. Instead, several problems appeared with the new versions:

Note: my sample lengths are generally around 11k to 30k.

  1. After loading gemma2, tokenizing the dataset with two or more processes hangs. If gemma2 is not loaded, tokenization runs normally, and it also runs normally with a single process.

  2. With DeepSpeed ZeRO-2 and flash attention enabled, training fails with RuntimeError: linalg.vector_norm: Expected a floating point or complex tensor as input. Got Long. From the DeepSpeed discussions I found, this error seems to indicate abnormal gradients, which is consistent with point 3 below.

  3. With DeepSpeed ZeRO-3 and flash attention enabled, no error is raised, but the initial loss is over 2000.

  4. With all the above settings unchanged except that flash attention is disabled in favor of eager attention, the code runs normally, and the initial loss matches my previous fine-tuning: it starts around 2 and then decreases.
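For context on point 4 and on the soft-capping remark above: Gemma 2 squashes its attention scores with `cap * tanh(scores / cap)` before the softmax, so an eager path and a flash-attention path only agree if both apply this step identically. The sketch below is a plain PyTorch illustration of that math, not the transformers implementation. The function names are hypothetical, the scaling uses `1/sqrt(head_dim)` as a simplifying assumption (Gemma 2 derives its scale from a config value), and the cap of 50.0 is, to my knowledge, the default `attn_logit_softcapping` in the Gemma 2 config; treat it as an assumption.

```python
import torch


def soft_cap(scores: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly squash scores into (-cap, cap); small scores are nearly unchanged.
    return cap * torch.tanh(scores / cap)


def eager_attention_with_softcap(q, k, v, cap):
    # q, k, v: (batch, heads, seq, head_dim).
    # Assumption: 1/sqrt(head_dim) scaling, for illustration only.
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale
    # Soft-capping is applied after scaling and before masking/softmax.
    scores = soft_cap(scores, cap)
    seq = q.shape[-2]
    causal = torch.ones(seq, seq, dtype=torch.bool).tril()
    scores = scores.masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 5, 8) for _ in range(3))
out = eager_attention_with_softcap(q, k, v, cap=50.0)  # assumed Gemma 2 cap value
```

With a very large cap the soft-capping becomes a no-op, so the result should match PyTorch's standard causal `scaled_dot_product_attention`; with a realistic cap it will not, which is why a kernel that skips or mis-applies the cap produces different losses.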

In summary, there seem to be some issues with the current flash attention integration in Hugging Face transformers. I don't think my RoPE PI length extension is the cause, because the position scaling is already applied to the query and key before attention is computed. I hope the maintainers can check whether these errors are real, and I greatly appreciate your work.
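The argument above rests on RoPE being applied to the query and key tensors before any attention kernel runs, so the choice of eager versus flash attention should not interact with PI scaling. A minimal sketch of linear position interpolation, assuming the "linear" scheme in which positions are divided by a scaling factor before the rotary angles are computed (as I understand the Llama code in transformers does). All function names here are hypothetical.

```python
import torch


def rope_cos_sin(positions, head_dim, base=10000.0, scaling_factor=1.0):
    # Standard RoPE inverse frequencies, one per pair of dimensions.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    # Linear position interpolation: squeeze positions by the scaling factor so a
    # longer context maps onto the position range seen during pre-training.
    t = positions.to(torch.float32) / scaling_factor
    freqs = torch.outer(t, inv_freq)          # (seq, head_dim / 2)
    emb = torch.cat((freqs, freqs), dim=-1)   # (seq, head_dim)
    return emb.cos(), emb.sin()


def rotate_half(x):
    # Pair dimension i with dimension i + head_dim / 2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(x, cos, sin):
    # Applied to q and k only; the attention kernel later sees already-rotated tensors.
    return x * cos + rotate_half(x) * sin
```

With `scaling_factor=2.0`, position 4 receives exactly the rotation that position 2 gets without scaling, and the rotation preserves vector norms, so nothing downstream of `apply_rope` depends on which attention backend is used.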

Who can help?

@ArthurZucker @muellerzr @SunMarc

Reproduction

RuntimeError: linalg.vector_norm: Expected a floating point or complex tensor as input. Got Long

[rank0]:   File "/root/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train
[rank0]:     return inner_training_loop(
[rank0]:   File "/root/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/trainer.py", line 2274, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs)
[rank0]:   File "/root/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/trainer.py", line 3344, in training_step
[rank0]:     self.accelerator.backward(loss, **kwargs)
[rank0]:   File "/root/miniconda3/envs/llm/lib/python3.10/site-packages/accelerate/accelerator.py", line 2126, in backward
[rank0]:     self.deepspeed_engine_wrapped.backward(loss, **kwargs)
[rank0]:   File "/root/miniconda3/envs/llm/lib/python3.10/site-packages/accelerate/utils/deepspeed.py", line 175, in backward
[rank0]:     self.engine.step()
[rank0]:   File "/root/miniconda3/envs/llm/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2169, in step
[rank0]:     self._take_model_step(lr_kwargs)
[rank0]:   File "/root/miniconda3/envs/llm/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2075, in _take_model_step
[rank0]:     self.optimizer.step()
[rank0]:   File "/root/miniconda3/envs/llm/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1842, in step
[rank0]:     scaled_global_grad_norm = self.scaled_global_norm()
[rank0]:   File "/root/miniconda3/envs/llm/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1789, in scaled_global_norm
[rank0]:     return torch.norm(torch.stack(norm_groups), p=norm_type)
[rank0]:   File "/root/miniconda3/envs/llm/lib/python3.10/site-packages/torch/functional.py", line 1631, in norm
[rank0]:     return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
[rank0]: RuntimeError: linalg.vector_norm: Expected a floating point or complex tensor as input. Got Long
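The last frames show DeepSpeed stacking per-group gradient norms and calling `torch.norm` on the result. The sketch below reproduces only that final call in isolation, to show that it fails with this message whenever the stacked norms are an integer (`Long`) tensor. The helper name is hypothetical, and how DeepSpeed ends up with integer norms is not established in this thread.

```python
import torch


def global_grad_norm(group_norms, norm_type=2.0):
    # Same shape as the failing line in the traceback:
    # torch.norm(torch.stack(norm_groups), p=norm_type)
    return torch.norm(torch.stack(group_norms), p=norm_type)


print(global_grad_norm([torch.tensor(3.0), torch.tensor(4.0)]))  # tensor(5.)

try:
    # 0-d int64 ("Long") tensors, e.g. if a norm was produced as an integer.
    global_grad_norm([torch.tensor(3), torch.tensor(4)])
except RuntimeError as err:
    print("RuntimeError:", err)
```

On recent PyTorch versions the second call raises a `RuntimeError` from `linalg.vector_norm`, which is why the message points at the dtype rather than at the gradients themselves.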

tdrussell commented 1 month ago

I might be running into the same issue. I have a custom training script (qlora-pipe) that I use to train LoRA on gemma2 27b. Eager attention works well. If I change one line to switch the attention implementation to flash attention, the weights and loss become NaN after a single training step. I am using flash attention 2.6.1 and the latest Transformers commit.

@kkk935208447 Can you check whether you have NaNs anywhere? In my case, if I don't explicitly check for NaN, the error that actually crashes my training script is Tensorboard trying to write a histogram with 0 elements. That is ultimately caused by the weights becoming NaN, so the error you see might not be the root cause.

EDIT: I would also add that my training script prints the loss on the very first step before it crashes, and it is ~2, which is expected. That implies inference (the forward pass) works fine, since it completes the first forward pass and computes the loss correctly. Things only become NaN after the first backward pass.
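The explicit NaN check suggested here can be sketched as a small helper that scans parameters and their gradients after `backward()`. This is a hypothetical helper demonstrated on a toy `nn.Linear`, not part of transformers or qlora-pipe. Under DeepSpeed ZeRO, gradients are partitioned and `param.grad` may not be populated, so it is most useful in a plain PyTorch loop.

```python
import torch
from torch import nn


def find_nonfinite(model: nn.Module):
    """Return names of parameters whose value or gradient contains NaN/Inf."""
    bad = []
    for name, param in model.named_parameters():
        if not torch.isfinite(param).all():
            bad.append(f"{name} (weight)")
        if param.grad is not None and not torch.isfinite(param.grad).all():
            bad.append(f"{name} (grad)")
    return bad


torch.manual_seed(0)
model = nn.Linear(4, 2)
loss = model(torch.randn(3, 4)).sum()
loss.backward()
healthy = find_nonfinite(model)  # [] after a normal backward pass

with torch.no_grad():
    model.weight.grad[0, 0] = float("nan")  # simulate a corrupted gradient
corrupted = find_nonfinite(model)  # ['weight (grad)']
```

Running such a check right after the first backward pass, before the optimizer step, distinguishes "gradients already NaN" from "weights became NaN during the update".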

kkk935208447 commented 1 month ago

@tdrussell Thank you very much for your reply, it makes me feel that the issue may not be due to my code.

I did not check for NaN loss values; I was inferring from the results. From the DeepSpeed discussions, RuntimeError: linalg.vector_norm: Expected a floating point or complex tensor as input. Got Long seems to be caused by abnormal gradients, which may come from the factors you mentioned.

I will take a look at whether the gradients in my forward propagation are normal when I have time.

kkk935208447 commented 1 month ago

@tdrussell I tried it, and the error seems to occur during backpropagation. I use 16 gradient accumulation steps, and all 16 forward passes print normally, but the error is raised once they are combined for backpropagation.
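This matches where the traceback fails: with gradient accumulation, every micro-batch runs forward and backward, but the global gradient norm and the optimizer step happen only once per accumulation cycle (`engine.step()` → `optimizer.step()` → `scaled_global_norm`). A plain PyTorch analogue with a toy model, assuming 16 accumulation steps as in the comment; this is not DeepSpeed's code path.

```python
import torch
from torch import nn

ACCUM_STEPS = 16  # matches the accumulation value mentioned in the comment

torch.manual_seed(0)
model = nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
micro_batches = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(ACCUM_STEPS)]

forward_losses = []
opt.zero_grad()
for x, y in micro_batches:
    loss = nn.functional.mse_loss(model(x), y)
    forward_losses.append(loss.item())  # every forward pass can be printed/inspected
    (loss / ACCUM_STEPS).backward()     # gradients accumulate in param.grad

# Only after all micro-batches: one global gradient norm, then one optimizer step.
# This is the stage at which the norm computation in the traceback fails.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
if torch.isfinite(total_norm):
    opt.step()
opt.zero_grad()
```

Because the norm is only computed at the end of the cycle, all 16 forward losses can look normal even when the accumulated gradients are already broken.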

ArthurZucker commented 1 month ago

Hey all! I think we merged a few PRs to circumvent this, can you confirm that you are still having these issues?

tdrussell commented 1 month ago

Thanks! Latest Transformers with flash attention 2.6.3 works now, I can train gemma 2 27b.

There is still a problem though, and maybe it's an existing issue. Flash attention only gives correct loss values at <=4096 sequence length. This is true even for eval only, without training anything. For example, at 5120 sequence length, running evaluation in my training script with eager attention gives a loss of 2.2, while switching to flash attention gives a loss of 3.6. Longer sequence lengths give progressively worse loss values.

I think there is some problem with the interaction of flash attention and the sliding window of gemma 2. But at <=4096 context length everything including training appears to work correctly.
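Why a mismatch would only show up above 4096 tokens: a sliding-window causal mask is identical to a plain causal mask until the sequence length exceeds the window. To my knowledge Gemma 2's default `sliding_window` is 4096, which lines up with the threshold reported here; treat that as an assumption. The sketch uses a toy window of 4, and the exact off-by-one convention for the window differs between implementations. If one attention path applies the window and the other does not (or applies it differently), their outputs can only diverge once the sequence is longer than the window.

```python
import torch


def causal_mask(seq_len: int) -> torch.Tensor:
    # True where query i may attend to key j (j <= i).
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return j <= i


def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # Additionally restrict each query to its most recent `window` keys.
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (i - j < window)


WINDOW = 4  # toy stand-in for the 4096-token window
same_up_to_window = all(
    torch.equal(causal_mask(n), sliding_window_causal_mask(n, WINDOW))
    for n in range(1, WINDOW + 1)
)
differs_beyond = not torch.equal(causal_mask(6), sliding_window_causal_mask(6, WINDOW))
```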

ArthurZucker commented 1 month ago

Interesting. If you are using the default HybridCache (i.e., simply calling generate), this should not normally happen.

YooSungHyun commented 3 weeks ago

Using ZeRO-2 with overlap_comm=False, it works for me.
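The workaround above corresponds to a DeepSpeed config along these lines. This is a hypothetical minimal config, not one posted in the thread: only `stage: 2` and `overlap_comm: False` come from the comment, and the `"auto"` entries assume the Hugging Face Trainer integration, which fills them in from `TrainingArguments`. Disabling `overlap_comm` removes the overlap of gradient communication with the backward pass, so it may cost some throughput; it is a reported workaround, not a confirmed fix.

```python
import json

# Hypothetical minimal ZeRO-2 config reflecting the reported workaround.
ds_config = {
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": False,  # the setting reported to avoid the error
    },
    # "auto" values are resolved by the HF Trainer integration (assumption: Trainer is used).
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "bf16": {"enabled": "auto"},
}

config_text = json.dumps(ds_config, indent=2)
print(config_text)
```

The dict (or a JSON file containing `config_text`) can then be passed as `TrainingArguments(deepspeed=...)`.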

kkk935208447 commented 3 weeks ago

> Hey all! I think we merged a few PRs to circumvent this, can you confirm that you are still having these issues?

Thank you for the relevant work.