huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Incorrect reference responses when using PEFT with PPOTrainer #1871

Open Sean-OB opened 1 month ago

Sean-OB commented 1 month ago

Below is a snippet from the generate method in ppo_trainer.py:

if generate_ref_response:
    ref_model = self.model if self.is_peft_model else self.ref_model
if isinstance(query_tensor, List):
    response = self._generate_batched(
        self.model,
        query_tensor,
        length_sampler=length_sampler,
        batch_size=batch_size,
        return_prompt=return_prompt,
        **generation_kwargs,
    )
    if generate_ref_response:
        ref_response = self._generate_batched(
            ref_model,
            query_tensor,
            length_sampler=length_sampler,
            batch_size=batch_size,
            return_prompt=return_prompt,
            **generation_kwargs,
        )

When training with PEFT, there is no separate ref_model; the reference forward pass reuses self.model, but inside a context manager that disables the adapters:

with torch.no_grad():
    all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
        self.model,
        queries,
        responses,
        model_inputs,
        response_masks=response_masks,
        return_logits=full_kl_penalty,
    )
    with self.optional_peft_ctx():
        ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
            self.model if self.is_peft_model else self.ref_model,
            queries,
            responses,
            model_inputs,
            return_logits=full_kl_penalty,
        )

However, the code that generates the reference responses does not use this context. As a result, the reference responses logged in the tables come from the RL model being optimized rather than from the reference model.
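
One possible fix (an untested sketch; it assumes optional_peft_ctx() behaves the same way here as it does in batched_forward_pass, i.e. it disables the adapters when is_peft_model is True and is a no-op otherwise) would be to wrap the reference generation in the same adapter-disabling context:

if generate_ref_response:
    # With PEFT, disable the adapters so the completions below come from the
    # frozen base model rather than the policy being trained.
    with self.optional_peft_ctx():
        ref_response = self._generate_batched(
            ref_model,
            query_tensor,
            length_sampler=length_sampler,
            batch_size=batch_size,
            return_prompt=return_prompt,
            **generation_kwargs,
        )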

To reproduce, run any training loop with PPOTrainer using your logging backend of choice (my setup uses WandB) and look at the table of responses: the reference responses are drawn from the same distribution as the model responses. Below is a screenshot from a dummy run in which I rewarded the model for outputting the word "but." The reference responses should not change over the course of training, yet they show the same bias.

[screenshot: W&B table of logged responses from the dummy run]
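
A dummy loop along these lines is enough to see it (a sketch, not my exact script; the model name, prompts, reward rule, and hyperparameters are placeholders):

import torch
from peft import LoraConfig
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

model_name = "gpt2"  # placeholder; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# LoRA adapter via peft_config; with PEFT no separate ref_model is passed,
# so the trainer should use the adapter-disabled base model as the reference.
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_name, peft_config=LoraConfig(task_type="CAUSAL_LM")
)

config = PPOConfig(model_name=model_name, batch_size=8, mini_batch_size=8, log_with="wandb")
ppo_trainer = PPOTrainer(config, model, ref_model=None, tokenizer=tokenizer)

prompts = ["Tell me something interesting:"] * config.batch_size
generation_kwargs = {"max_new_tokens": 32, "do_sample": True, "pad_token_id": tokenizer.eos_token_id}

for _ in range(50):
    queries = [tokenizer(p, return_tensors="pt").input_ids.squeeze(0) for p in prompts]
    responses, ref_responses = ppo_trainer.generate(
        queries, return_prompt=False, generate_ref_response=True, **generation_kwargs
    )
    texts = tokenizer.batch_decode(responses)
    ref_texts = tokenizer.batch_decode(ref_responses)
    # Dummy reward: +1 whenever the completion contains "but".
    rewards = [torch.tensor(1.0 if "but" in t else 0.0) for t in texts]
    stats = ppo_trainer.step(queries, responses, rewards)
    batch = {"query": prompts, "response": texts, "ref_response": ref_texts}
    ppo_trainer.log_stats(stats, batch, rewards, columns_to_log=["query", "response", "ref_response"])

After a few hundred steps the ref_response column tracks the "but"-heavy policy outputs instead of staying unchanged.
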
skylooop commented 1 month ago

I observed the same problem with DPOTrainer: setting generate_during_eval=True in DPOConfig produces reference outputs from the current model being trained.
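
A setup roughly like the following shows it (a sketch; the model, dataset, and hyperparameters are placeholders, and the argument names assume recent transformers/TRL releases):

from datasets import Dataset
from peft import LoraConfig
from transformers import AutoTokenizer
from trl import DPOConfig, DPOTrainer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
tokenizer.pad_token = tokenizer.eos_token

# Tiny placeholder preference dataset in the prompt/chosen/rejected format.
pairs = {
    "prompt": ["Say something:"] * 32,
    "chosen": ["Sure, here is a friendly and detailed reply."] * 32,
    "rejected": ["No."] * 32,
}
train_dataset = Dataset.from_dict(pairs)
eval_dataset = Dataset.from_dict(pairs)

# generate_during_eval=True samples completions at evaluation time and logs a table
# of policy vs. reference generations; with peft_config and ref_model=None, the
# "reference" column should come from the base model with adapters disabled.
args = DPOConfig(
    output_dir="dpo-peft-ref-test",
    generate_during_eval=True,
    report_to="wandb",
    eval_strategy="steps",
    eval_steps=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
)
trainer = DPOTrainer(
    model="gpt2",      # placeholder; a model name string also works here
    ref_model=None,    # with PEFT, no separate reference model is passed
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=LoraConfig(task_type="CAUSAL_LM"),
)
trainer.train()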

github-actions[bot] commented 2 weeks ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.