huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

error when using PPO in Gemma #1663

Closed mostafamdy closed 4 months ago

mostafamdy commented 5 months ago

System Info

Hi, I tried using PPO with a Gemma model but I get the error below. I think the issue is related to is_encoder_decoder:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[26], line 68
     66 print(response_tensors)
     67 #### Run PPO step
---> 68 stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
     69 ppo_trainer.log_stats(stats, batch, rewards)
     70 break

File /opt/conda/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File /opt/conda/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py:721, in PPOTrainer.step(self, queries, responses, scores, response_masks)
    718 full_kl_penalty = self.config.kl_penalty == "full"
    720 with torch.no_grad():
--> 721     all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
    722         self.model,
    723         queries,
    724         responses,
    725         model_inputs,
    726         response_masks=response_masks,
    727         return_logits=full_kl_penalty,
    728     )
    729     with self.optional_peft_ctx():
    730         ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
    731             self.model if self.is_peft_model else self.ref_model,
    732             queries,
   (...)
    735             return_logits=full_kl_penalty,
    736         )

File /opt/conda/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File /opt/conda/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py:994, in PPOTrainer.batched_forward_pass(self, model, queries, responses, model_inputs, return_logits, response_masks)
    992 if response_masks is not None:
    993     response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
--> 994 logits, _, values = model(**input_kwargs)
    996 if self.is_encoder_decoder:
    997     input_ids = input_kwargs["decoder_input_ids"]

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1568, in Module._call_impl(self, *args, **kwargs)
   1565     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1566     args = bw_hook.setup_input_hook(args)
-> 1568 result = forward_call(*args, **kwargs)
   1569 if _global_forward_hooks or self._forward_hooks:
   1570     for hook_id, hook in (
   1571         *_global_forward_hooks.items(),
   1572         *self._forward_hooks.items(),
   1573     ):
   1574         # mark that always called hook is run

File /opt/conda/lib/python3.10/site-packages/trl/models/modeling_value_head.py:171, in AutoModelForCausalLMWithValueHead.forward(self, input_ids, past_key_values, attention_mask, **kwargs)
    168 if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
    169     kwargs.pop("past_key_values")
--> 171 base_model_output = self.pretrained_model(
    172     input_ids=input_ids,
    173     attention_mask=attention_mask,
    174     **kwargs,
    175 )
    177 last_hidden_state = base_model_output.hidden_states[-1]
    178 lm_logits = base_model_output.logits

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/peft/peft_model.py:1326, in PeftModelForSeq2SeqLM.forward(self, input_ids, attention_mask, inputs_embeds, decoder_input_ids, decoder_attention_mask, decoder_inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1324     with self._enable_peft_forward_hooks(**kwargs):
   1325         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1326         return self.base_model(
   1327             input_ids=input_ids,
   1328             attention_mask=attention_mask,
   1329             inputs_embeds=inputs_embeds,
   1330             decoder_input_ids=decoder_input_ids,
   1331             decoder_attention_mask=decoder_attention_mask,
   1332             decoder_inputs_embeds=decoder_inputs_embeds,
   1333             labels=labels,
   1334             output_attentions=output_attentions,
   1335             output_hidden_states=output_hidden_states,
   1336             return_dict=return_dict,
   1337             **kwargs,
   1338         )
   1340 batch_size = _get_batch_size(input_ids, inputs_embeds)
   1341 if decoder_attention_mask is not None:
   1342     # concat prompt attention mask

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:161, in BaseTuner.forward(self, *args, **kwargs)
    160 def forward(self, *args: Any, **kwargs: Any):
--> 161     return self.model.forward(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

TypeError: GemmaForCausalLM.forward() got an unexpected keyword argument 'decoder_input_ids'


RUFFY-369 commented 5 months ago

@mostafamdy I think the same: **input_kwargs can't contain decoder_input_ids unless self.is_encoder_decoder is True. That said, self.is_encoder_decoder gets its value from the model's config, and there is no is_encoder_decoder attribute in Gemma's configuration file: https://github.com/huggingface/transformers/blob/c681b58b06f6fb8b5c331f380548af3b4b33f881/src/transformers/models/gemma/configuration_gemma.py#L27. Can you debug by printing ppo_trainer.is_encoder_decoder? Or just set it to False and check whether the error is gone.
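For example, a quick check along these lines (a sketch; ppo_trainer and model refer to your own PPOTrainer and value-head model instances):

# what the trainer actually uses
print(ppo_trainer.is_encoder_decoder)
# what the underlying Gemma config reports (expected: missing/False for a decoder-only model)
print(getattr(model.pretrained_model.config, "is_encoder_decoder", None))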

mostafamdy commented 5 months ago

Do you know how we can set it to False without changing the source code?

amyeroberts commented 5 months ago

cc @ArthurZucker

RUFFY-369 commented 5 months ago

@mostafamdy here is the code to get a PPOTrainer instance, which you may have used (I don't know your exact script):

from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model
from trl.core import respond_to_batch

access_token = 'to_fill'
model = AutoModelForCausalLMWithValueHead.from_pretrained('google/gemma-2b', token=access_token)
# the config dict doesn't have an 'is_encoder_decoder' attribute
print("config", model.pretrained_model.config)

model_ref = create_reference_model(model)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", token=access_token)
tokenizer.pad_token = tokenizer.eos_token

# initialize trainer
ppo_config = PPOConfig(batch_size=1, mini_batch_size=1)

# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

# get model response
response_tensor = respond_to_batch(model, query_tensor)

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)

So, AutoModelForCausalLMWithValueHead takes google/gemma-2b or google/gemma-7b as the pretrained model name for PPOTrainer, and the gemma-2b config it exposes (via model.pretrained_model.config, which the ValueHead wrapper uses) has no is_encoder_decoder attribute.

The TypeError: GemmaForCausalLM.forward() got an unexpected keyword argument 'decoder_input_ids' that you get can be verified against this code line: https://github.com/huggingface/transformers/blob/4ad5adaf1d224fa28ffa8e1d124846b1d55a5d0e/src/transformers/models/gemma/modeling_gemma.py#L1073. So, clearly we need to set is_encoder_decoder in PPOTrainer to False: somehow it ends up True, which leads to decoder_input_ids being passed to GemmaForCausalLM.forward().

So, try this simple line of code for changing the value of self.is_encoder_decoder in PPOTrainer:

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)
ppo_trainer.is_encoder_decoder = False
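(Set the attribute right after constructing the trainer, before the first ppo_trainer.step call.)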

mostafamdy commented 5 months ago

Thanks @RUFFY-369, I tried this but it's not working for me:

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)
ppo_trainer.is_encoder_decoder = False

It worked after adding this code:

# this line is very important
def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

here is the full code

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer
from peft import LoraConfig, get_peft_model
from trl import AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLM.from_pretrained(    
    config.model_name,
    device_map="auto",
    torch_dtype="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,  # Loading weights in 4-bit format
        bnb_4bit_quant_type="nf4",  # Using non-linear quantization with 4 bits
        bnb_4bit_compute_dtype=torch.bfloat16,  # Using bfloat16 for computation
        bnb_4bit_use_double_quant=True  # Using double quantization
    ),
    trust_remote_code=True
)

# this line is very important
def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

tokenizer = GemmaTokenizer.from_pretrained(config.model_name)

# tokens = tokenizer("Hi How are", return_tensors='pt')
# outputs = model(**tokens)

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1
)

model = get_peft_model(model, peft_config)

print("Using lora")
print(model.print_trainable_parameters())
# print(tokens.keys())

# outputs = model.generate(**tokens)
# outputs = model(**tokens)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model,
    torch_dtype=torch.bfloat16,
    is_trainable=True,
)
# outputs = model(**tokens)
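# print_number_of_trainable_model_parameters is a helper defined elsewhere in the notebook (not shown here)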
print(f'PPO model parameters to be updated (ValueHead + 769 params):\n{print_number_of_trainable_model_parameters(model)}\n')
print(model.v_head)

Have a nice day 😄

mostafamdy commented 5 months ago

I don't know if it's correct or not.

I found this in the PPO trainer tests (test_ppo_trainer).

RUFFY-369 commented 5 months ago

@mostafamdy Yeah, I checked that test script out while tracking down the value changes of is_encoder_decoder and decoder_input_ids, since I didn't have your script. Also, apologies that I couldn't test the code I suggested; CUDA was running out of memory on my system, probably because another model was training.

So, for the code you mentioned above: are you using all of the code from the test file, or just bits of it to make your script work?

mostafamdy commented 5 months ago

Thank you so much for your help. No, I used only this part of the code:

# this line is very important
def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

RUFFY-369 commented 5 months ago

You're welcome! Glad I could be of help. 😄 Oh, okay, good that it worked for you; one of the uses of test files is that they often contain solutions to general errors as well.

Have a nice day :+1: 😄

ArthurZucker commented 4 months ago

Hey both, is the issue that the newly resized embeddings don't require grad even if the rest of the model does?

RUFFY-369 commented 4 months ago

Hi @ArthurZucker, what I found out is that there was a PR for a DPO + gradient checkpointing issue where, if "one uses gradient_checkpointing we need to attach hooks to enable inputs to have requires grad to true, otherwise the training will either silently fail or completely fail". The fix was as follows:

elif getattr(args, "gradient_checkpointing", False):
    # For backward compatibility with older versions of transformers
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:

        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

It's the same as what @mostafamdy found in test_ppo_trainer. But the main ppo_trainer file doesn't have this gradient_checkpointing fix, unlike other trainers such as dpo_trainer. That's why the ppo_trainer() instance leads to this error, and it gets fixed by the same block of code from the test file and dpo_trainer.py.
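Until ppo_trainer gets the same guard, it can be applied manually in user code; a sketch adapted from the block above (model here is the base transformers model, before it is wrapped for PPO):

# apply the same input-grad fix manually, right after loading the base model
if hasattr(model, "enable_input_require_grads"):
    model.enable_input_require_grads()
else:
    def make_inputs_require_grad(module, input, output):
        output.requires_grad_(True)

    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)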

ArthurZucker commented 4 months ago

cc @younesbelkada 🤗

younesbelkada commented 4 months ago

This is a TRL issue so transferring it here!

younesbelkada commented 4 months ago

https://github.com/huggingface/trl/pull/1664 should fix the issue