ArcherShirou opened 2 months ago
Did you set fp16 = True or bf16 = True in the trainer?
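For reference, here is a minimal sketch of picking those two trainer flags so that exactly one is enabled, preferring bf16 when the GPU supports it. It uses plain PyTorch checks (`torch.cuda.is_available()`, `torch.cuda.is_bf16_supported()`); the helper name `mixed_precision_flags` is hypothetical and this is not LLaMA Factory's or Unsloth's own logic. Note that these flags generally control the trainer's mixed-precision autocast and do not by themselves change the dtype the model weights were loaded in.

```python
from typing import Optional

import torch


def mixed_precision_flags(bf16_supported: Optional[bool] = None) -> dict:
    """Return fp16/bf16 trainer flags with exactly one of them enabled.

    If bf16_supported is None, detect support on the current CUDA device.
    (Hypothetical helper for illustration.)
    """
    if bf16_supported is None:
        # Guard is_bf16_supported() behind is_available() so CPU-only runs
        # evaluate to False instead of querying a missing device.
        bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    # Prefer bf16 where supported, otherwise fall back to fp16.
    # Both are dtypes FlashAttention accepts; fp32 is not.
    return {"bf16": bool(bf16_supported), "fp16": not bf16_supported}
```

The result can be splatted into a Hugging Face config, e.g. `TrainingArguments(output_dir="outputs", **mixed_precision_flags())`; TRL's `DPOConfig` subclasses `TrainingArguments`, so the same two flags apply there.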
We are using LLaMA Factory and Unsloth to run DPO training on preference data with a length of 11k. During parameter setup we set bf16 to True and verified that the data type was torch.bfloat16. We did not encounter any errors when running SFT training under the same framework. However, even after changing all data types to torch.bfloat16 in the framework's DPO trainer, we still get the same error. This may be an issue with the framework, but its logic for calling Unsloth appears to be correct, so we are reaching out to you for guidance.
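Checking the trainer's bf16 setting does not necessarily show what dtype the model's weights (and therefore Q, K and V) are actually in; in QLoRA setups some non-quantized modules can be kept or upcast to float32 (PEFT's `prepare_model_for_kbit_training`, for example, upcasts non-quantized parameters to float32). A minimal sketch for tallying parameter dtypes on any `torch.nn.Module`, assuming you can get a handle on the model object the trainer uses; both helper names are hypothetical. 4-bit quantized weights will typically show up under an integer storage dtype, which is expected.

```python
from collections import Counter

import torch
from torch import nn


def parameter_dtype_summary(model: nn.Module) -> Counter:
    """Count how many parameter tensors of each dtype a model holds."""
    return Counter(str(p.dtype) for p in model.parameters())


def float32_parameter_names(model: nn.Module) -> list:
    """List names of parameters still in float32 (candidates to inspect)."""
    return [name for name, p in model.named_parameters() if p.dtype == torch.float32]


# Toy example: one float32 layer next to one bfloat16 layer.
toy = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4).to(torch.bfloat16))
print(parameter_dtype_summary(toy))
print(float32_parameter_names(toy))
```

Running the same two helpers on the real model right before `trainer.train()` (for both the policy and, if separate, the reference model) would show whether any float32 modules remain.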
Hmm, to be honest I'm not sure how LLaMA Factory calls Unsloth. Let me check and get back to you. For DPO, I normally suggest using Unsloth's notebooks directly to reduce errors.
When training the Qwen2-72B-Instruct model with SFT using QLoRA + Unsloth, an error occurs with the message "FlashAttention only supports fp16 and bf16 data types." Below is the specific error traceback:

[rank0]:   File "/llm-align/miniconda3/envs/unsloth/lib/python3.10/site-packages/accelerate/hooks.py", line 169, in new_forward
[rank0]:     output = module._old_forward(*args, **kwargs)
[rank0]:   File "/llm-align/unsloth/unsloth/models/llama.py", line 393, in LlamaAttention_fast_forward
[rank0]:     A = flash_attn_func(Q, K, V, causal = True)
[rank0]:   File "/llm-align/miniconda3/envs/unsloth/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 880, in flash_attn_func
[rank0]:     return FlashAttnFunc.apply(
[rank0]:   File "/llm-align/miniconda3/envs/unsloth/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:   File "/llm-align/miniconda3/envs/unsloth/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 546, in forward
[rank0]:     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
[rank0]:   File "/llm-align/miniconda3/envs/unsloth/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 52, in _flash_attn_forward
[rank0]:     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
[rank0]: RuntimeError: FlashAttention only supports fp16 and bf16 data types

How can this issue be resolved? Thanks
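The traceback shows `flash_attn_func(Q, K, V, causal = True)` receiving tensors in a dtype other than fp16/bf16, most likely float32. The cleaner fix is to find where float32 enters (model loading dtype, or code that upcasts), but as a stop-gap for debugging, here is a minimal sketch of a wrapper that casts Q, K and V to one supported dtype before the call and casts the result back. Both helper names are hypothetical, this is not Unsloth's code, and bf16 as the default target is an assumption (use fp16 on GPUs without bf16 support). The real `flash_attn_func` needs CUDA tensors, so the check below uses a stand-in function.

```python
import torch

# Dtypes FlashAttention accepts, per the error message in this issue.
FLASH_ATTN_DTYPES = (torch.float16, torch.bfloat16)


def to_flash_attn_dtype(q, k, v, target=torch.bfloat16):
    """Return q, k, v in one FlashAttention-compatible dtype.

    If q is already fp16/bf16 and k, v match it, the tensors are returned
    unchanged; otherwise all three are cast to `target`.
    """
    if q.dtype in FLASH_ATTN_DTYPES and k.dtype == q.dtype and v.dtype == q.dtype:
        return q, k, v
    return q.to(target), k.to(target), v.to(target)


def flash_attn_with_cast(attn_fn, q, k, v, **kwargs):
    """Call attn_fn on cast inputs, then cast the output back to q's dtype."""
    orig_dtype = q.dtype
    q, k, v = to_flash_attn_dtype(q, k, v)
    out = attn_fn(q, k, v, **kwargs)
    # Restore the caller's dtype so downstream layers (e.g. o_proj) see
    # the same dtype they would have without the cast.
    return out.to(orig_dtype)
```

At the call site from the traceback this would read `A = flash_attn_with_cast(flash_attn_func, Q, K, V, causal = True)`. If this makes the error disappear, it confirms a dtype mismatch upstream, which is the thing to fix properly.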