CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
MIT License

`position_ids` error in accelerate PPO trainer #564

Closed pbarragan closed 1 year ago

pbarragan commented 1 year ago

🐛 Describe the bug

Hi,

I'm very new to TRLX, PEFT, and Hugging Face, so I'm not sure whether I just have some simple configuration wrong, but I'm trying to recreate the notebook here, originally from this page. The notebook is a bit out of date, so I've slowly been working through various issues. Here's what I've done so far:

I am now running this code (modified slightly from the notebook):

import json
import os
import sys
from typing import List

import torch
from datasets import load_dataset
from transformers import pipeline

import trlx
from trlx.data.default_configs import (
    ModelConfig,
    OptimizerConfig,
    PPOConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
    TRLConfig,
)

def get_positive_score(scores):
    "Extract value associated with a positive sentiment from pipeline's output"
    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]

def myconfig():
    return TRLConfig(
        train=TrainConfig(
            seq_length=128, # 1024
            epochs=1000,
            total_steps=10000,
            batch_size=32,
            checkpoint_interval=10000,
            eval_interval=100,
            pipeline="PromptPipeline",
            trainer="AcceleratePPOTrainer",
            save_best=False,
            tracker=None,
        ),
        model=ModelConfig(model_path='facebook/opt-1.3b', num_layers_unfrozen=2,
                          peft_config={"peft_type": "LORA", "r": 8, "lora_alpha": 16, "lora_dropout": 0.05}),

        tokenizer=TokenizerConfig(tokenizer_path='facebook/opt-1.3b', truncation_side="right"),
        optimizer=OptimizerConfig(
            name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
        ),
        scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-4)),
        method=PPOConfig(
            name="PPOConfig",
            num_rollouts=128,
            chunk_size=128,
            ppo_epochs=4,
            init_kl_coef=0.05,
            target=6,
            horizon=10000,
            gamma=1,
            lam=0.95,
            cliprange=0.2,
            cliprange_value=0.2,
            vf_coef=1,
            scale_reward="ignored",
            ref_mean=None,
            ref_std=None,
            cliprange_reward=10,
            gen_kwargs=dict(
                max_new_tokens=40,
                top_k=0,
                top_p=1.0,
                do_sample=True,
            ),
        ),
    )

def main(hparams={}):
    # Merge sweep config with default config if given
    config = TRLConfig.update(myconfig().to_dict(), hparams)

    if torch.cuda.is_available():
        device = int(os.environ.get("LOCAL_RANK", 0))
    else:
        device = -1

    sentiment_fn = pipeline(
        "sentiment-analysis",
        "lvwerra/distilbert-imdb",
        top_k=2,
        truncation=True,
        batch_size=256,
        device=device,
    )

    def reward_fn(samples: List[str], **kwargs) -> List[float]:
        sentiments = list(map(get_positive_score, sentiment_fn(samples)))
        return sentiments

    imdb = load_dataset("imdb", split="train+test")

    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

    eval_prompts = [
        "What made this movie so distinctly",
        "This movie was terrible",
        "I was surprised at how this movie",
        "I didn't expect this movie was",
    ] * 16

    trlx.train(
        reward_fn=reward_fn,
        prompts=prompts,
        eval_prompts=eval_prompts,
        config=config,
    )
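In the notebook I then invoke this from a separate cell, which is where the traceback below starts:

main({})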

With this code, I still get a few warnings about things I can fix up later, but I run into the following error:

fatal: not a git repository (or any parent up to mount point /)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).
[RANK 0] Starting training
[RANK 0] Collecting rollouts
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[8], line 1
----> 1 main({})

Cell In[7], line 118, in main(hparams)
    115 print("Eval prompts set.")
    117 print("Start training.")
--> 118 trlx.train(
    119     reward_fn=reward_fn,
    120     prompts=prompts,
    121     eval_prompts=eval_prompts,
    122     config=config,
    123 )
    124 print("End training.")

File ~/trlx/trlx/trlx.py:129, in train(model_path, reward_fn, dataset, samples, rewards, prompts, eval_prompts, metric_fn, config, stop_sequences)
    126 if config.train.resume_from_checkpoint and os.path.exists(config.train.resume_from_checkpoint):
    127     trainer.load(config.train.resume_from_checkpoint)
--> 129 trainer.learn()
    130 return trainer

File ~/trlx/trlx/trainer/accelerate_base_trainer.py:521, in AccelerateRLTrainer.learn(self)
    516 """
    517 Samples batches from `self.store`, updates model and periodically evaluates it on `self.eval_dataloader`
    518 """
    519 logger.info("Starting training")
--> 521 self.prepare_learning()
    522 self.iter_count = 0
    523 self.nth_evaluation = 0

File ~/trlx/trlx/trainer/accelerate_ppo_trainer.py:234, in AcceleratePPOTrainer.prepare_learning(self)
    231 eval_dataloader = self.eval_pipeline.create_loader(self.config.method.chunk_size)
    232 self.eval_dataloader = self.accelerator.prepare_data_loader(eval_dataloader)
--> 234 self.make_experience(self.config.method.num_rollouts)
    236 self.train_dataloader = self.create_train_dataloader()
    238 self.n_inner_epochs = self.config.method.ppo_epochs

File ~/trlx/trlx/trainer/accelerate_ppo_trainer.py:418, in AcceleratePPOTrainer.make_experience(self, num_rollouts, iter_count)
    416 position_ids.masked_fill_(attention_mask == 0, 1)
    417 with torch.no_grad():
--> 418     logits, *_, values = self.model(
    419         all_tokens, attention_mask=attention_mask, position_ids=position_ids
    420     )
    421     # TODO(dahoas): When hydra model works need to also support generation on hydra head
    422     if hasattr(self.model, "frozen_head") or self.model.peft_type:

File /state/partition1/llgrid/pkg/anaconda/anaconda3-2023a-pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/trlx/trlx/models/modeling_ppo.py:329, in AutoModelForCausalLMWithValueHead.forward(self, input_ids, attention_mask, past_key_values, position_ids, head_mask, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, ignore_peft_adapter)
    327         outputs = self.base_model.base_model(**forward_kwargs)
    328 else:
--> 329     outputs = self.base_model(**forward_kwargs)
    331 # TODO: Apply PEFT to value branch
    332 if self.num_value_layers_unfrozen > 0:

File /state/partition1/llgrid/pkg/anaconda/anaconda3-2023a-pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/lib/python3.9/site-packages/peft/peft_model.py:442, in PeftModel.forward(self, *args, **kwargs)
    438 def forward(self, *args: Any, **kwargs: Any):
    439     """
    440     Forward pass of the model.
    441     """
--> 442     return self.get_base_model()(*args, **kwargs)

File /state/partition1/llgrid/pkg/anaconda/anaconda3-2023a-pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() got an unexpected keyword argument 'position_ids'

I don't think the first line about git is related (it just happens to print in the same section), although I don't actually know what that problem is either. The `TypeError` at the bottom looks very similar to #416, so I thought I'd ask here if anyone has any clues. However, there are some additional PEFT frames in my backtrace, so I'm not sure whether the problem is actually related to PEFT or something else entirely. I also wondered whether I need different versions of some of these libraries. Anyway, I would appreciate any help. Thanks!
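If it's relevant, my current guess is that the base model's forward simply doesn't accept `position_ids` (OPT computes its positional embeddings from the attention mask internally). A minimal way to check that guess, using a smaller OPT checkpoint purely for illustration:

import inspect

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
# OPT's forward signature has no position_ids parameter, so this prints False
print("position_ids" in inspect.signature(model.forward).parameters)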

Which trlX version are you using?

Main

Additional system and package information

PEFT version 0.5.0, Transformers 4.33.3

pbarragan commented 1 year ago

As an update, while continuing to poke around to understand what was wrong, I attempted to run the original TRLX example that the above code is based on. Because I'm running on an HPC with no internet access and only V100s available, I had to make some modifications:

This seemed to work except for a few things.

I'm not sure whether this experiment working helps illuminate what might be going on with my original post; I haven't figured that out yet. The goal of this follow-up was not to ask a host of new questions, but only to provide another data point that might help clarify what is and is not broken in the previous example.

Note that the only other difference is that I ran this as a Python file, while the previous post was an attempt using a Jupyter notebook on this machine.

However, when I reintroduce a LoRA configuration, I get the same `position_ids` error again.

peft_config=LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05,)
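For completeness, this is the same ModelConfig from my first post, just with a LoraConfig object instead of the dict (a sketch, assuming `from peft import LoraConfig`):

from peft import LoraConfig
from trlx.data.default_configs import ModelConfig

model = ModelConfig(
    model_path="facebook/opt-1.3b",
    num_layers_unfrozen=2,
    peft_config=LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05),
)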

So I'm trying to investigate what about that is causing the problem and how to fix it.

Thanks in advance for any help.

maxreciprocate commented 1 year ago

Hi @pbarragan! Sorry for the late response. The `position_ids` bug in question has been resolved; however, I can't reproduce your hanging behaviour in my environment (which has access to the Internet). If you suspect checkpointing is one possible reason for the hang, you can still disable it by setting `checkpoint_interval` to 99999999 and `save_best` to False. The warning you observed is expected in this example, and it is also innocuous because of the tiny size of the reward model used.
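For example (a sketch using the TrainConfig fields already present in your script; the exact value is arbitrary):

config = myconfig()
config.train.checkpoint_interval = 99999999  # effectively never write a checkpoint
config.train.save_best = False               # and don't keep a "best" checkpoint either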

pbarragan commented 1 year ago

@maxreciprocate, thank you so much for the reply and the fix in #566! Unfortunately, I need checkpointing as part of my workflow, so I'm going to keep poking around to see what might be causing the hang. I'm still guessing it has something to do with file locking being disabled in some of the directories on the server I have to use, but I don't have enough details yet to ask a sensible question. I'll try to give this a shot as soon as possible and reopen if something related to the `position_ids` bug is still not working. Take care!