huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

precompute_ref_log_probs not working correctly? #2423

Open dakru012 opened 15 hours ago

dakru012 commented 15 hours ago

Reproduction

from datasets import Dataset
from trl import DPOConfig, DPOTrainer, extract_prompt

# model, ref, and tokenizer are assumed to be an already loaded policy model,
# reference model, and tokenizer

# basic example dataset
train_dataset = Dataset.from_dict({
    "chosen": [
        [{"role": "system", "content": "Answer truthfully."}, {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
        [{"role": "system", "content": "Answer truthfully."}, {"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
    ],
    "rejected": [
        [{"role": "system", "content": "Answer truthfully."}, {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
        [{"role": "system", "content": "Answer truthfully."}, {"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
    ],
})

train_dataset = train_dataset.map(extract_prompt)

training_args = DPOConfig(
    output_dir='DPO_output',
    logging_steps=10,
    loss_type='sigmoid',
    bf16=True,
    precompute_ref_log_probs=True,
)

trainer = DPOTrainer(
    model=model,
    ref_model=ref,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)

trainer.train()

Expected behavior

Hi, I have a question about a potential issue, or a misunderstanding on my side. The point of precompute_ref_log_probs is to compute the reference log probabilities for the whole dataset before training starts, so that during training we can simply load the precomputed values and free the GPU memory otherwise needed for the ref model, right? However, it seems like the precomputed log probabilities are never actually loaded.
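To make the expectation concrete, here is the kind of quick check I had in mind (only a sketch; I am assuming that get_train_dataloader() is what triggers the precompute pass and adds the columns, and the key names are the ones used in get_batch_loss_metrics below):

dataloader = trainer.get_train_dataloader()   # should run the ref model over the dataset once when precompute_ref_log_probs=True
print(trainer.train_dataset.column_names)     # I expect "ref_chosen_logps" and "ref_rejected_logps" to show up here

batch = next(iter(dataloader))
# I also expected both keys in the collated batch, but they are missing,
# so the else branch below recomputes them with the ref model:
print("ref_chosen_logps" in batch, "ref_rejected_logps" in batch)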

In the corresponding part of get_batch_loss_metrics():

def get_batch_loss_metrics(
    self,
    model,
    batch: Dict[str, Union[List, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
    metrics = {}

    model_output = self.concatenated_forward(model, batch)

    # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model
    if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:
        # LOADED
        ref_chosen_logps = batch["ref_chosen_logps"]
        ref_rejected_logps = batch["ref_rejected_logps"]
    else:
        # COMPUTED AGAIN
        ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)
...

The if condition is never true, even when the log probabilities were precomputed, resulting in unnecessary forward passes through the ref model. This is because the PreferenceCollator does not include ref_chosen_logps and ref_rejected_logps in the batch.

I made some changes to the Collator to include those, but first I wanted to make sure that I understood the precompute_ref_log_probs argument correctly.
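For illustration, the change I tried is roughly the following. This is only a minimal sketch: I am assuming the collator in question is DataCollatorForPreference from trl.trainer.dpo_trainer with a torch_call hook, and the key names are the ones get_batch_loss_metrics checks for; it is not meant as the final patch.

import torch
from trl.trainer.dpo_trainer import DataCollatorForPreference

class PreferenceCollatorWithRefLogps(DataCollatorForPreference):
    # Sketch: keep the precomputed ref log probs in the collated batch instead of dropping them.
    def torch_call(self, examples):
        output = super().torch_call(examples)
        if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
            # Pass the precomputed values through so get_batch_loss_metrics takes the "loaded" branch
            output["ref_chosen_logps"] = torch.tensor([ex["ref_chosen_logps"] for ex in examples])
            output["ref_rejected_logps"] = torch.tensor([ex["ref_rejected_logps"] for ex in examples])
        return output

Passing something like data_collator=PreferenceCollatorWithRefLogps(pad_token_id=tokenizer.pad_token_id) to DPOTrainer should then let the precomputed values reach get_batch_loss_metrics.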

qgallouedec commented 8 hours ago

That's a good catch, thanks @dakru012! Do you want to submit a PR to fix it?

SwayamInSync commented 8 hours ago

> That's a good catch, thanks @dakru012! Do you want to submit a PR to fix it?

I think these lines within concatenated_forward are the culprit: the names should be [ref_chosen_logps, ref_rejected_logps] instead of [chosen_logps, rejected_logps], and then the same case needs to be handled in the compute_ref_log_probs function:

        output["chosen_logps"] = all_logps[:num_examples]
        output["rejected_logps"] = all_logps[num_examples:]

Let me know if a PR is already open; otherwise I can include the relevant fixes in #2426 or make a new one.

dakru012 commented 7 hours ago

@SwayamInSync I don't think that's the problem. I will take a look at it again and do a PR, but it is already midnight here so I gotta sleep first 😴