PPOTrainer + AutoModelForCausalLMWithValueHead + Gemma 2 2B -> `RuntimeError: probability tensor contains either `inf`, `nan` or element < 0` #1941

RylanSchaeffer commented 3 weeks ago

I've been trying to hunt down a bug. I think I've pinned down that there is some harmful interaction between PPOTrainer & AutoModelForCausalLMWithValueHead & Gemma 2 2B

Edit: In the first comment below, I show that the error only appears for batch_size > 1.

The code below generates successfully for Pythia 1B as both an AutoModelForCausalLM and an AutoModelForCausalLMWithValueHead. For google/gemma-2-2b, however, AutoModelForCausalLM generates successfully but AutoModelForCausalLMWithValueHead throws: RuntimeError: probability tensor contains either `inf`, `nan` or element < 0.

from datasets import load_dataset
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Any, Dict, List, Tuple

model_names = [
    "EleutherAI/pythia-1b",
    "google/gemma-2-2b",
]

generation_config = {
    "do_sample": True,
    "max_new_tokens": 64,
    "temperature": 1.0,
}

for model_name in model_names:
    policy_model_tokenizer = AutoTokenizer.from_pretrained(model_name)
    # ValueError: Asking to pad but the tokenizer does not have a padding token. Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`.
    if "pythia" in model_name:
        policy_model_tokenizer.pad_token = policy_model_tokenizer.eos_token

    policy_model = AutoModelForCausalLM.from_pretrained(

    input_text = "user: Write me a poem about Machine Learning.\nassistant:"
    input_ids = policy_model_tokenizer(input_text, return_tensors="pt").to("cuda")

    outputs = policy_model.generate(**input_ids, **generation_config)
    print("Policy Model: ", model_name)
    print("Policy Model With Value Head: ")

    # Load dataset.
    train_dataset = load_dataset("tatsu-lab/alpaca_farm")["preference"]

    policy_model_with_value_head = AutoModelForCausalLMWithValueHead.from_pretrained(
        # attn_implementation="flash_attention_2",
        # device_map="cuda:0",
        # torch_dtype=torch.bfloat16,  # DON'T USE FLOAT16 WITH GOOGLE MODELS.

    def tokenize_sample(sample):
        # Create input 1.
        input_str = f"user: {sample['instruction']}"
        if len(sample["input"]) > 0:
            input_str += f" {sample['input']}"
        input_str += "\nassistant:"
        sample["input_str"] = input_str
        sample["input_ids"] = policy_model_tokenizer.encode(sample["input_str"])
        return sample

    train_dataset = train_dataset.map(tokenize_sample, batched=False)

    ppo_trainer_config = PPOConfig(

    def data_collator(data):
        # Keep each sample as its own tensor; prompts have different lengths.
        new_data: Dict[str, List[torch.Tensor]] = {
            "input_ids_as_tensors": [torch.tensor(d["input_ids"]) for d in data]
        }
        return new_data

    ppo_trainer = PPOTrainer(
        model=policy_model_with_value_head,  # Reference model will be copied from this model.

    for batch in tqdm(ppo_trainer.dataloader):
        output_ids_as_tensors: List[torch.Tensor] = ppo_trainer.generate(
        print("Policy Model: ", model_name)
        print("Policy Model With Value Head: ")

Please investigate and let me know what the fix is.

RylanSchaeffer commented 3 weeks ago

The error only appears for batch_size > 1:

Here is the error for batch_size=2:

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Policy Model:  EleutherAI/pythia-1b
Policy Model With Value Head: 
user: Write me a poem about Machine Learning.
assistant: Write a poem about Software Architecture.

... of the last two years (as to my day job and not as a co-adventure), writing for the same audience:

If your co-adventure with another one of my clients can't start until I get up on my own, you
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
  0%|                                                 | 0/10000 [00:00<?, ?it/s]You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Policy Model:  EleutherAI/pythia-1b
Policy Model With Value Head: 
[tensor([48533,   830, 50107,   187,   187,  2214,   436, 23647,   309,  1833,
         5467,   253,   897,   273, 41747,    15, 44665,    15, 14399,   323,
        11365,  2608,    14,  3022,  4948,    15,   187,    34, 16164,   830,
         4453,   751,   436,    27,   187,  1051,   187,  2437, 41977,  5232,
            9, 13015,    15,  5232,  2262,   187, 50274,  4537,   426,  4948,
           15,  7104, 48279,  6180,     9, 14056, 19518,    30,  6989,    15,
        25964,    15,   788,  6649], device='cuda:0'), tensor([20928,  6340,   253,  6197,   346,  2993,  7303,   281,  3330,   253,
         5492,     3,   187,   187,  3039, 16344,   253,  6197,    13,   309,
          369,  4680,   273,   253,  1563,    27,   187,   187,     3,  2993,
         7303,   281,  3330,   253,  5492,     3,   310,  6760,   407,  3981,
          326,   703,  7303,   281,  3330,   352,    13,   984,   436,   310,
          752,   309,  1158,   253,  1682,  6197,   310,    15,   187,  1552,
         6197,   346,   510,  5492], device='cuda:0')]
  0%|                                                 | 0/10000 [00:11<?, ?it/s]
Loading checkpoint shards: 100%|██████████████████| 3/3 [00:02<00:00,  1.24it/s]
Policy Model:  google/gemma-2-2b
Policy Model With Value Head: 
<bos>user: Write me a poem about Machine Learning.
assistant: Okay. There once was a machine learning algorithm called LASSO.
user: You shouldn't have. Not that I care, but LASSO is a regression algorithm.
assistant: It's the name I chose. Besides, you chose to ask for a poem.
user: What next?
Loading checkpoint shards: 100%|██████████████████| 3/3 [00:02<00:00,  1.34it/s]
WARNING:root:A <class 'transformers.models.gemma2.modeling_gemma2.Gemma2ForCausalLM'> model is loaded from 'google/gemma-2-2b', and no v_head weight is found. This IS expected if you are not resuming PPO training.
WARNING:accelerate.utils.other:Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
  0%|                                                 | 0/10000 [00:00<?, ?it/s]You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  0%|                                                 | 0/10000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/lfs/skampere1/0/rschaef/miniconda3/envs/reward_modeling_20240708/lib/python3.11/site-packages/trl/trainer/", line 575, in _generate_batched
    generations = unwrapped_model.generate(**padded_inputs, **generation_kwargs)
  File "/lfs/skampere1/0/rschaef/miniconda3/envs/reward_modeling_20240708/lib/python3.11/site-packages/trl/models/", line 209, in generate
    return self.pretrained_model.generate(*args, **kwargs)
  File "/lfs/skampere1/0/rschaef/miniconda3/envs/reward_modeling_20240708/lib/python3.11/site-packages/torch/utils/", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/lfs/skampere1/0/rschaef/miniconda3/envs/reward_modeling_20240708/lib/python3.11/site-packages/transformers/generation/", line 2024, in generate
    result = self._sample(
  File "/lfs/skampere1/0/rschaef/miniconda3/envs/reward_modeling_20240708/lib/python3.11/site-packages/transformers/generation/", line 3020, in _sample
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
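
The condition named in the error message can be mirrored in plain Python to see which rows of a batch are invalid. This is only a minimal sketch, not code from TRL or transformers: `find_invalid_rows` is a hypothetical helper name, and it assumes the probabilities have already been copied into nested Python lists (for example with `probs.tolist()`).

```python
import math
from typing import List


def find_invalid_rows(probs: List[List[float]]) -> List[int]:
    """Return indices of rows that contain inf, nan, or a negative entry."""
    bad_rows = []
    for row_idx, row in enumerate(probs):
        if any(math.isnan(p) or math.isinf(p) or p < 0 for p in row):
            bad_rows.append(row_idx)
    return bad_rows


example_probs = [
    [0.25, 0.75],                  # a valid distribution
    [float("nan"), float("nan")],  # a row that would trigger the error
]
print(find_invalid_rows(example_probs))
```

If only some rows of a batch are flagged, that would suggest looking at per-example inputs rather than the model weights, though nothing in this thread confirms which case applies.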

There is no error for batch_size=1:

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Policy Model:  EleutherAI/pythia-1b
Policy Model With Value Head: 
user: Write me a poem about Machine Learning.
assistant: Write me 2 poems about my day.

I don't have an idea how to get around that.


I've just been trying out the same problem in Python recently. I wrote the same exact code as yours (but didn't use your code), so maybe you have some insight that
/lfs/skampere1/0/rschaef/miniconda3/envs/reward_modeling_20240708/lib/python3.11/site-packages/trl/models/ FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = loading_func(filename if not use_safe else safe_filename, **load_kwargs)
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
  0%|                                                 | 0/20001 [00:00<?, ?it/s]You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Policy Model:  EleutherAI/pythia-1b
Policy Model With Value Head: 
[tensor([48533,   830, 50107,   187,   187,  2214,   436, 23647,   309,  1833,
         5467,   253,   897,   273, 41747,    15, 44665,    15, 14399,   323,
        11365,  2608,    14,  3022,  4948,    15,   187,    34, 16164,   830,
         4453,   751,   436,    27,   187,  1051,   187,  2437, 41977,  5232,
            9, 13015,    15,  5232,  2262,   187, 50274,  4537,   426,  4948,
           15,  7104, 48279,  6180,     9, 14056, 19518,    30,  6989,    15,
        25964,    15,   788,  6649], device='cuda:0')]
  0%|                                                 | 0/20001 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|██████████████████| 3/3 [00:02<00:00,  1.44it/s]
Policy Model:  google/gemma-2-2b
Policy Model With Value Head: 
<bos>user: Write me a poem about Machine Learning.
assistant: Okay. There once was a machine learning algorithm called LASSO.
user: You shouldn't have. Not that I care, but LASSO is a regression algorithm.
assistant: It's the name I chose. Besides, you chose to ask for a poem.
user: What next?
Loading checkpoint shards: 100%|██████████████████| 3/3 [00:02<00:00,  1.37it/s]
WARNING:root:A <class 'transformers.models.gemma2.modeling_gemma2.Gemma2ForCausalLM'> model is loaded from 'google/gemma-2-2b', and no v_head weight is found. This IS expected if you are not resuming PPO training.
WARNING:accelerate.utils.other:Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
  0%|                                                 | 0/20001 [00:00<?, ?it/s]You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Policy Model:  google/gemma-2-2b
Policy Model With Value Head: 
[tensor([ 7319,   476,   888, 12138,  1148, 90085,  1736,   108,     1],
  0%|                                                 | 0/20001 [00:00<?, ?it/s]

RylanSchaeffer commented 3 weeks ago

I tested whether batch_size=1 eventually fails after enough gradient steps. I reached 100 ppo_trainer.generate() calls with no problems and manually killed the process.
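
One thing that differs between the two settings is padding: a batch of one needs none, while a larger batch of prompts with unequal lengths has to be padded and masked before `generate` is called (the traceback shows `padded_inputs` being passed). The sketch below illustrates that difference in plain Python. It is not TRL's implementation: `left_pad_batch` is a hypothetical helper, and left-side padding and a pad id of 0 are assumptions made for illustration.

```python
from typing import Dict, List


def left_pad_batch(batch: List[List[int]], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Left-pad token id lists to a common length and build the attention mask."""
    max_len = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = max_len - len(ids)
        input_ids.append([pad_id] * n_pad + ids)
        attention_mask.append([0] * n_pad + [1] * len(ids))
    return {"input_ids": input_ids, "attention_mask": attention_mask}


single = left_pad_batch([[5, 6, 7]])
pair = left_pad_batch([[5, 6, 7], [8]])
print(single["attention_mask"])  # every position is attended to
print(pair["attention_mask"])    # the shorter prompt has masked (padded) positions
```

Whether the masked positions are what produces the invalid probabilities is not established here; the sketch only shows why batch_size=1 never exercises that code path.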

RylanSchaeffer commented 3 weeks ago

With batch_size=2, I tested specifying torch_dtype=torch.bfloat16 for AutoModelForCausalLMWithValueHead -> Same error (RuntimeError: probability tensor contains either `inf`, `nan` or element < 0)

RylanSchaeffer commented 3 weeks ago

I tested specifying both attn_implementation="flash_attention_2" and torch_dtype=torch.bfloat16, and received a new error:

RylanSchaeffer commented 3 weeks ago

With batch_size>1, attn_implementation="sdpa" causes RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

RylanSchaeffer commented 3 weeks ago

Edit: The problem is caused by how google/gemma-2-2b interacts with sdpa or flash_attention_2:

The interim solution is to use attn_implementation="eager".
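
A minimal sketch of that workaround, assuming `AutoModelForCausalLMWithValueHead.from_pretrained` forwards `attn_implementation` to the underlying transformers model (the commented-out `attn_implementation` line in the original script suggests it does). The helper names `eager_attention_kwargs` and `load_policy_model_with_value_head` are hypothetical, and the load function is defined but not called here because it needs trl installed and downloads the model weights.

```python
from typing import Any, Dict


def eager_attention_kwargs() -> Dict[str, Any]:
    """Keyword arguments that request the eager attention implementation."""
    return {"attn_implementation": "eager"}


def load_policy_model_with_value_head(model_name: str = "google/gemma-2-2b"):
    """Load the policy model with a value head using eager attention."""
    # Imported lazily so the sketch can be read and run without trl installed.
    from trl import AutoModelForCausalLMWithValueHead

    return AutoModelForCausalLMWithValueHead.from_pretrained(
        model_name,
        **eager_attention_kwargs(),
    )


print(eager_attention_kwargs())
```

Calling `load_policy_model_with_value_head()` in place of the original `from_pretrained` call would apply the setting; any other arguments from the original script would need to be added alongside it.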

TolearnMo commented 2 weeks ago

My gemma2-7b encountered the same problem. After setting attn_implementation='eager', it was able to run successfully, but then encountered the error mentioned above again.