huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

PPOTrainer with HuggingFace PreTrainedModelWrapper Models #2357

Open Mrinh212375 opened 1 week ago

Mrinh212375 commented 1 week ago

System Info

Kaggle Notebook

Reproduction

import trl
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, BitsAndBytesConfig, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
from trl import create_reference_model
from datasets import Dataset, load_dataset
import numpy as np
import json
import random
import pandas as pd
import torch

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",
    #device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer.pad_token = tokenizer.eos_token
ref_model = create_reference_model(model)
rm_model = AutoModelForSequenceClassification.from_pretrained('/kaggle/input/reward_model/transformers/default/1/reward_model_50epochs/content/reward_model_50epochs')
dataset = load_dataset("argilla/reward-model-data-falcon")
def create_dataset(row):
  # Ensure 'response-1' always holds the preferred answer; swap if the annotator chose response 2.
  if row['choose-best']['value'][0] == 2:
    row['response-1'], row['response-2'] = row['response-2'], row['response-1']
  return row

dataset = dataset.map(create_dataset)
dataset=dataset['train']
dataset
dataset = dataset.rename_columns({
    'instruction': 'prompt',
    'response-1': 'response',
    'response-2': 'rejected',
})

# Remove specified columns
dataset = dataset.remove_columns(['choose-best', 'external_id','rejected'])

# Display the dataset
dataset
def tokenize_function(example):
    # Tokenize the prompt and store it as input_ids. Also return the response.
    return {
        "input_ids": tokenizer(example["prompt"], return_tensors="pt", truncation=True, max_length=512)["input_ids"].squeeze(),
        "response": example["response"],
    }

ppo_training_dataset = dataset.map(tokenize_function, batched=False)
config = PPOConfig(
    #model_name="google/gemma-2-2b-it",
    learning_rate=1.41e-5,
    mini_batch_size=5,
    batch_size=20,
    output_dir='/kaggle/working/'
)

ppo_trainer = PPOTrainer(config=config,
                         processing_class = 'PreTrainedTokenizerBase' ,
                         policy = model,
                         ref_policy = ref_model,
                         reward_model = rm_model,
                         #tokenizer=tokenizer,
                         train_dataset=ppo_training_dataset,
                         data_collator=collator)

Output:

Traceback (most recent call last):
  Cell In[83], line 9
      1 config = PPOConfig(
      2     #model_name="google/gemma-2-2b-it",
      3     learning_rate=1.41e-5,
   (...)
      6     output_dir='/kaggle/working/'
      7 )
----> 9 ppo_trainer = PPOTrainer(config=config,
     10                          processing_class = 'PreTrainedTokenizerBase' ,
     11                          policy = model,
     12                          ref_policy = ref_model,
     13                          reward_model = rm_model,
     14                          #tokenizer=tokenizer,
     15                          train_dataset=ppo_training_dataset,
     16                          data_collator=collator)

File /opt/conda/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py:176, in PPOTrainer.__init__(self, config, processing_class, policy, ref_policy, reward_model, train_dataset, value_model, data_collator, eval_dataset, optimizers, callbacks)
    172 #########
    173 # setup model, optimizer, and others
    174 #########
    175 for module in [policy, ref_policy, value_model, reward_model]:
--> 176     disable_dropout_in_model(module)
    177 if args.stop_token and args.stop_token == "eos":
    178     args.stop_token_id = processing_class.eos_token_id

File /opt/conda/lib/python3.10/site-packages/trl/trainer/utils.py:788, in disable_dropout_in_model(model)
    787 def disable_dropout_in_model(model: torch.nn.Module) -> None:
--> 788     for module in model.modules():
    789         if isinstance(module, torch.nn.Dropout):
    790             module.p = 0

AttributeError: 'NoneType' object has no attribute 'modules'

Expected behavior

https://github.com/huggingface/trl/blob/v0.12.1/trl/trainer/ppo_trainer.py#L91 In the PPOTrainer class, how do I pass nn.Module models if I'm working with a HF PreTrainedModelWrapper model? Is there any way to extract the underlying nn.Module from the PreTrainedModelWrapper?
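
Something like this is what I'm after (a rough sketch on my side, assuming the wrapper keeps the base transformers model on its pretrained_model attribute):

# Illustrative only: AutoModelForCausalLMWithValueHead subclasses PreTrainedModelWrapper,
# which, as far as I can tell, stores the wrapped transformers model on .pretrained_model.
wrapped = AutoModelForCausalLMWithValueHead.from_pretrained("google/gemma-2-2b-it")
base_module = wrapped.pretrained_model  # the underlying torch.nn.Module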


ccs96307 commented 1 week ago

Hi, it looks like the error arises because the PPOTrainer class expects a value_model to be defined and passed in, which appears to be required in the current TRL version. The disable_dropout_in_model method is likely encountering NoneType because value_model wasn’t specified, and thus defaults to None.
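
For example, something along these lines should work (a rough sketch only; reusing the policy checkpoint with a scalar head as the value model is just an illustration, any AutoModelForSequenceClassification with num_labels=1 can serve):

# Sketch: give PPOTrainer an explicit value model so disable_dropout_in_model never receives None.
value_model = AutoModelForSequenceClassification.from_pretrained(
    "google/gemma-2-2b-it",  # illustrative choice: same checkpoint as the policy, with a scalar head
    num_labels=1,
    torch_dtype=torch.bfloat16,
)

ppo_trainer = PPOTrainer(
    config=config,
    processing_class=tokenizer,   # the tokenizer instance itself, not a string
    policy=model,
    ref_policy=ref_model,
    reward_model=rm_model,
    value_model=value_model,      # required by this PPOTrainer version
    train_dataset=ppo_training_dataset,
    data_collator=collator,
)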

Hope this helps!

Mrinh212375 commented 1 week ago

Hi, thanks. I've now passed the policy model as the value_model as well; I'd thought it was optional, so I hadn't passed anything before. Anyway, the error is gone. Also, I can call the ppo_trainer.train() method directly, right? Unlike the older version, there's no need to write the PPO training loop myself. Can you please clarify this point?

ccs96307 commented 1 week ago

Glad to hear the error is resolved!

Yes, as far as I know, you can directly call the ppo_trainer.train() method without needing to write a training loop.
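
If it helps, the whole loop then collapses to something like this (a sketch; the save path is just an example):

# Sketch: with the current PPOTrainer API, no manual generate/reward/step loop is needed.
ppo_trainer.train()
ppo_trainer.save_model("/kaggle/working/ppo_gemma")  # example output path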