rank0: Traceback (most recent call last):
rank0: File "/opt/tmp/nlp/wzh/LLM-Dojo/rlhf/rloo_train.py", line 167, in <module>
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/trl/trainer/rloo_trainer.py", line 246, in train
rank0: query_response, logits = generate(
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/trl/trainer/utils.py", line 1102, in generate
rank0: output = lm_backbone.generate(
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/peft/peft_model.py", line 1491, in generate
rank0: outputs = self.base_model.generate(*args, **kwargs)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
rank0: return func(*args, **kwargs)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/transformers/generation/utils.py", line 1758, in generate
rank0: result = self._sample(
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/transformers/generation/utils.py", line 2397, in _sample
rank0: outputs = self(
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
rank0: return self._call_impl(*args, **kwargs)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
rank0: return forward_call(*args, **kwargs)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/unsloth/models/llama.py", line 855, in _CausalLM_fast_forward
rank0: outputs = self.model(
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
rank0: return self._call_impl(*args, **kwargs)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
rank0: return forward_call(*args, **kwargs)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/unsloth/models/llama.py", line 710, in LlamaModel_fast_forward
rank0: layer_outputs = torch.utils.checkpoint.checkpoint(
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/_compile.py", line 24, in inner
rank0: return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
rank0: return fn(*args, **kwargs)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
rank0: return fn(*args, **kwargs)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 487, in checkpoint
rank0: return CheckpointFunction.apply(function, preserve, *args)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/autograd/function.py", line 598, in apply
rank0: return super().apply(*args, **kwargs)  # type: ignore[misc]
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 262, in forward
rank0: outputs = run_function(*args)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/unsloth/models/llama.py", line 706, in custom_forward
rank0: return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
rank0: return self._call_impl(*args, **kwargs)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
rank0: return forward_call(*args, **kwargs)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/unsloth/models/llama.py", line 453, in LlamaDecoderLayer_fast_forward
rank0: hidden_states, self_attn_weights, present_key_value = self.self_attn(
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
rank0: return self._call_impl(*args, **kwargs)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
rank0: return forward_call(*args, **kwargs)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/unsloth/models/llama.py", line 343, in LlamaAttention_fast_forward
rank0: Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/unsloth/kernels/rope_embedding.py", line 178, in inplace_rope_embedding
rank0: Q = Slow_RoPE_Embedding.apply(Q, cos, sin, position_ids)
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/torch/autograd/function.py", line 598, in apply
rank0: return super().apply(*args, **kwargs)  # type: ignore[misc]
rank0: File "/home/nlp/miniconda3/envs/codellm2/lib/python3.9/site-packages/unsloth/kernels/rope_embedding.py", line 154, in forward
rank0: Q *= cos
rank0: RuntimeError: The size of tensor a (32) must match the size of tensor b (64) at non-singleton dimension 1
When I used RLOOTrainer from the trl library for RLHF, I loaded both the policy model and the ref_policy model through Unsloth, but generation failed with the error above. Does this mean that models loaded via Unsloth are not supported by RLOOTrainer?