unslothai / unsloth

Finetune Llama 3, Mistral, Phi & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0
12.35k stars 799 forks

Does it support rloo_trainer of trl? #725

Open mst272 opened 3 days ago

mst272 commented 3 days ago

```
[rank0]: Traceback (most recent call last):
[rank0]:   File "/opt/tmp/nlp/wzh/LLM-Dojo/rlhf/rloo_train.py", line 167, in <module>
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/trl/trainer/rloo_trainer.py", line 246, in train
[rank0]:     query_response, logits = generate(
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/trl/trainer/utils.py", line 1102, in generate
[rank0]:     output = lm_backbone.generate(
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/peft/peft_model.py", line 1491, in generate
[rank0]:     outputs = self.base_model.generate(*args, **kwargs)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/transformers/generation/utils.py", line 1758, in generate
[rank0]:     result = self._sample(
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/transformers/generation/utils.py", line 2397, in _sample
[rank0]:     outputs = self(
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/unsloth/models/llama.py", line 855, in _CausalLM_fast_forward
[rank0]:     outputs = self.model(
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/unsloth/models/llama.py", line 710, in LlamaModel_fast_forward
[rank0]:     layer_outputs = torch.utils.checkpoint.checkpoint(
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/_compile.py", line 24, in inner
[rank0]:     return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 487, in checkpoint
[rank0]:     return CheckpointFunction.apply(function, preserve, *args)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/autograd/function.py", line 598, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 262, in forward
[rank0]:     outputs = run_function(*args)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/unsloth/models/llama.py", line 706, in custom_forward
[rank0]:     return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/unsloth/models/llama.py", line 453, in LlamaDecoderLayer_fast_forward
[rank0]:     hidden_states, self_attn_weights, present_key_value = self.self_attn(
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/unsloth/models/llama.py", line 343, in LlamaAttention_fast_forward
[rank0]:     Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/unsloth/kernels/rope_embedding.py", line 178, in inplace_rope_embedding
[rank0]:     Q = Slow_RoPE_Embedding.apply(Q, cos, sin, position_ids)
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/autograd/function.py", line 598, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:   File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/unsloth/kernels/rope_embedding.py", line 154, in forward
[rank0]:     Q *= cos
[rank0]: RuntimeError: The size of tensor a (32) must match the size of tensor b (64) at non-singleton dimension 1
```

When I use RLOOTrainer from the TRL library for RLHF, with both the policy model and the ref_policy model loaded through Unsloth, I get the error above. Is this setup not supported?
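The last frame fails on the in-place multiply `Q *= cos` inside Unsloth's `Slow_RoPE_Embedding.forward`. As a minimal illustration in plain PyTorch (CPU only), the same broadcasting error can be reproduced. The tensor shapes below are hypothetical and chosen only so that dimension 1 is 32 vs 64; the actual shapes in this run are not shown in the traceback.

```python
import torch

# Hypothetical shapes: only dimension 1 differs (32 vs 64), and neither size is 1,
# so PyTorch cannot broadcast the two tensors together.
Q = torch.randn(1, 32, 5, 128)
cos = torch.randn(1, 64, 5, 128)

msg = ""
try:
    Q *= cos  # same in-place multiply as the last frame of the traceback
except RuntimeError as err:
    msg = str(err)
print(msg)

# For contrast: a size-1 dimension broadcasts, so this in-place multiply succeeds.
ok = torch.randn(1, 32, 5, 128)
ok *= torch.randn(1, 1, 5, 128)
print(ok.shape)
```

Elementwise ops only broadcast when each pair of dimensions is equal or one of them is 1. Which tensor ends up with the unexpected size during TRL's generation step is not established in this thread.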

mst272 commented 3 days ago

I also tried the same setup with DPOV2Trainer and got the same error. The script runs fine when I don't use Unsloth.

danielhanchen commented 2 days ago

Hmm, it's probably because these trainers need generation steps during training. I'll have to look into it.