huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

[Trainer] Changing the dataset dynamically during training #2227

Open ilyasoulk opened 1 week ago

ilyasoulk commented 1 week ago

Hello,

I am currently training a model using DPO, and I'm adapting the dataset dynamically during training. My current approach looks like this:

trainer = DPOTrainer(
    model,
    None,  # ref_model; omitted since a peft_config is passed
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
    beta=args.beta,
    max_prompt_length=1024,
    max_length=1536,
)
for i in range(repetitions):
    train_result = trainer.train()
    # Adapt the dataset based on the training result
    dataset = get_adapted_dataset(train_result)
    with PartialState().local_main_process_first():
        # Tokenize the updated dataset
        print("Updating the training dataset")
        trainer.train_dataset = dataset.map(trainer.tokenize_row, num_proc=None)

Is this the correct way to adapt the dataset during training, or is there a more appropriate approach for this scenario?

qgallouedec commented 1 week ago

Using an iterable dataset might be better suited. If the way you update the dataset depends on the training results, you'll probably also need to set up a callback.
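A minimal, library-free sketch of the pattern being suggested (the names `DynamicExamples` and `pool` are illustrative, not part of trl): keep the examples in a mutable holder and expose them through a generator, so each fresh pass over the data sees the latest version. With the `datasets` library you would wrap the holder with `IterableDataset.from_generator(pool)` and mutate it from a `TrainerCallback` (e.g. in `on_epoch_end`), instead of restarting `trainer.train()` in a loop.

```python
class DynamicExamples:
    """Mutable pool of examples, read lazily by a generator."""

    def __init__(self, examples):
        self.examples = examples

    def __call__(self):
        # Called at the start of every pass over the dataset, so any
        # replacement of `self.examples` is picked up on the next epoch.
        yield from self.examples


pool = DynamicExamples([{"prompt": "2+2=", "chosen": "4", "rejected": "5"}])

# First pass sees the initial data.
first_pass = list(pool())

# "Callback" step: between passes, swap in the adapted data
# (in the original snippet this would come from `get_adapted_dataset`).
pool.examples = [{"prompt": "3+3=", "chosen": "6", "rejected": "7"}]

# The next pass sees the updated data without rebuilding the dataset object.
second_pass = list(pool())
```

The key point is that the generator is re-invoked on each iteration of the dataloader, so no dataset object needs to be replaced on the trainer mid-training.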