pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

[RFC] Proximal Policy Optimisation #812

Open SalmanMohammadi opened 2 months ago

SalmanMohammadi commented 2 months ago

Implementing Proximal Policy Optimisation

I've used some of the PyTorch RFC template here for clarity.

Authors:

Summary

I'd like to add support for fine-tuning models using the Proximal Policy Optimisation (PPO) reinforcement learning (RL) algorithm. Like Direct Preference Optimisation (DPO), PPO is widely used to align language models with human preferences; it is the core optimisation algorithm in Reinforcement Learning from Human Feedback (RLHF).

PPO optimises a language model that acts as a policy. The action space is the model's vocabulary (each generated token is an action), observations are prompts drawn from some prompt distribution, and the reward is a scalar indicating how "preferred" the model's completion is for a given prompt. The reward is usually produced by a reward model calibrated to human preferences.
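For a language model, each generated token is an action taken in the state "prompt plus tokens generated so far". A minimal PyTorch sketch of PPO's clipped surrogate policy loss over such per-token actions is shown below; the function name and the `clip_eps=0.2` default are illustrative assumptions, and advantage estimation (e.g. GAE) is left out.

```python
import torch


def ppo_clipped_policy_loss(
    logprobs: torch.Tensor,      # log pi_theta(a_t | s_t) under the current policy, [b, T]
    old_logprobs: torch.Tensor,  # log pi_theta_old(a_t | s_t) recorded during the rollout, [b, T]
    advantages: torch.Tensor,    # advantage estimates A_t, [b, T]
    clip_eps: float = 0.2,       # clipping range epsilon (illustrative default)
) -> torch.Tensor:
    # Probability ratio r_t = pi_theta / pi_theta_old, computed in log space.
    ratio = torch.exp(logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # PPO maximises the pessimistic min(unclipped, clipped); we minimise its negation.
    return -torch.min(unclipped, clipped).mean()


# Toy example: one completion of three tokens (three actions).
old = torch.log(torch.tensor([[0.5, 0.5, 0.5]]))
new = torch.log(torch.tensor([[0.5, 0.8, 0.2]]))
adv = torch.tensor([[1.0, 1.0, -1.0]])
loss = ppo_clipped_policy_loss(new, old, adv)
print(loss)
```

The clipping removes the incentive to move a token's probability ratio outside `[1 - clip_eps, 1 + clip_eps]` in a single update, which is what keeps the policy close to the one that generated the rollout.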

Motivation

This repository helps make a fascinating technology even more accessible. Supporting PPO will help users to understand and explore LLM alignment techniques in native PyTorch, which is already widely adopted and easy to get started with.

Proposed Implementation

Prior art

TRL implements a generalised PPO trainer. The policy is a thin wrapper around a pre-trained LLM that adds a value-function head, which is optimised during PPO training. A frozen copy of the model being trained is also initialised to serve as the reference model.
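To make this setup concrete, here is a minimal PyTorch sketch of a policy that shares a language-model trunk with a scalar value head, plus a frozen reference copy used to compute a per-token KL estimate. `TinyLM` and `PolicyWithValueHead` are hypothetical stand-ins, not TRL's or torchtune's classes.

```python
import copy

import torch
from torch import nn


class TinyLM(nn.Module):
    """Hypothetical stand-in for a pre-trained causal LM."""

    def __init__(self, vocab_size: int = 16, hidden_dim: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.body = nn.Linear(hidden_dim, hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, tokens: torch.Tensor):
        hidden = torch.tanh(self.body(self.embed(tokens)))  # [b, s, hidden_dim]
        return self.lm_head(hidden), hidden                 # logits [b, s, vocab], hidden states


class PolicyWithValueHead(nn.Module):
    """Thin wrapper: reuses the LM's hidden states and adds a scalar value head."""

    def __init__(self, lm: TinyLM, hidden_dim: int = 8):
        super().__init__()
        self.lm = lm
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, tokens: torch.Tensor):
        logits, hidden = self.lm(tokens)
        values = self.value_head(hidden).squeeze(-1)  # one value estimate per position, [b, s]
        return logits, values


torch.manual_seed(0)
policy = PolicyWithValueHead(TinyLM())
# Frozen reference model: a copy of the LM taken before any PPO update.
reference = copy.deepcopy(policy.lm).eval().requires_grad_(False)

tokens = torch.randint(0, 16, (2, 5))
logits, values = policy(tokens)
with torch.no_grad():
    ref_logits, _ = reference(tokens)

# Log-probs of each observed next token (predictions shifted by one position).
next_tokens = tokens[:, 1:, None]
logp = torch.log_softmax(logits[:, :-1], dim=-1).gather(-1, next_tokens).squeeze(-1)
ref_logp = torch.log_softmax(ref_logits[:, :-1], dim=-1).gather(-1, next_tokens).squeeze(-1)
# Simple per-token KL estimate log(pi / pi_ref), commonly used to penalise drift from the reference.
kl = logp - ref_logp
```

Before any update the policy and reference are identical, so `kl` is zero; it grows as PPO moves the policy away from the reference.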

Feedback and thoughts are very much appreciated. I'm hoping to add value here and I'm grateful for any guidance to help me do so.

kartikayk commented 2 months ago

@SalmanMohammadi thanks so much for the high quality RFC. PPO would be an amazing technique to add to torchtune!

Overall the plan looks good. A quick comment on the model itself:

Integration of reward models into the codebase

The description here lends itself very well to the concepts we have in torchtune.

Based on what you described, I'd imagine that you would build a PPO model by adding a custom component builder which keeps most of the architecture the same but replaces the output layer with what you have in mind. Does this generally make sense? Happy to answer more questions on this.
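The "same architecture, different output layer" idea can be sketched with two builders that differ only in the final projection. `ToyDecoder`, `toy_lm` and `toy_classifier` are hypothetical stand-ins, assuming only that the decoder takes its final projection as a constructor argument and stores it as an `output` attribute.

```python
import torch
from torch import nn


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for a decoder: embeddings -> norm -> output projection."""

    def __init__(self, vocab_size: int, embed_dim: int, output: nn.Module):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.norm = nn.LayerNorm(embed_dim)
        self.output = output  # final projection: the only part the builders swap

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.output(self.norm(self.tok_embeddings(tokens)))


def toy_lm(vocab_size: int = 32, embed_dim: int = 16) -> ToyDecoder:
    # Language-model builder: project hidden states onto the vocabulary.
    return ToyDecoder(vocab_size, embed_dim, nn.Linear(embed_dim, vocab_size, bias=False))


def toy_classifier(num_classes: int = 1, vocab_size: int = 32, embed_dim: int = 16) -> ToyDecoder:
    # Classifier builder: identical architecture, but the output layer maps to num_classes.
    return ToyDecoder(vocab_size, embed_dim, nn.Linear(embed_dim, num_classes, bias=False))


tokens = torch.randint(0, 32, (2, 7))
print(toy_lm()(tokens).shape)          # torch.Size([2, 7, 32])
print(toy_classifier()(tokens).shape)  # torch.Size([2, 7, 1])
```

Because the classifier builder reuses the decoder class unchanged, the language model and the reward model share all code except one line in the builder.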

I'd need some more details on the implementation since there's a lot going on here, but I think these would be best communicated in the form of a prototype that does what you had in mind.

I'm also cc-ing @vmoens who's the RL expert in PyTorch for his thoughts and feedback!

SalmanMohammadi commented 2 months ago

Thanks so much for your feedback @kartikayk.

I think it makes sense to start with the reward model implementation. There's a pre-trained reward model for Mistral-7B. Implementing component and model builders for Mistral to start could allow for easy testing. There might need to be some small modifications to convert_weights.py to support loading reward models.

In Hugging Face Transformers, reward models are typically loaded through the AutoModelForSequenceClassification auto class. This is just a sequence model with a linear classification layer (example for Mistral-7B) slapped on top of the final hidden states from the underlying decoder-only model.
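A classification head like this produces a score at every position, so a single reward per sequence has to be pooled from it. A minimal sketch, similar in spirit to what Hugging Face's sequence-classification heads do (this is not their code), reads each sequence's score at its last non-padding token; it assumes right padding and that `pad_token_id` never appears inside a real sequence.

```python
import torch


def pool_last_non_pad(scores: torch.Tensor, tokens: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Select each sequence's score at its last non-padding position.

    scores: [b, s, num_labels] per-position outputs of the linear score head
    tokens: [b, s] input IDs, assumed right-padded with pad_token_id
    returns: [b, num_labels]
    """
    # With right padding, the last real token sits at (number of non-pad tokens) - 1.
    last_idx = (tokens != pad_token_id).sum(dim=-1) - 1
    batch_idx = torch.arange(tokens.size(0), device=tokens.device)
    return scores[batch_idx, last_idx]


pad = 0
tokens = torch.tensor([[5, 6, 7, 0, 0],
                       [8, 9, 3, 4, 2]])
scores = torch.arange(10, dtype=torch.float32).reshape(2, 5, 1)
rewards = pool_last_non_pad(scores, tokens, pad)
print(rewards)  # tensor([[2.], [9.]])
```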

Writing my thought process below, I wonder if it makes sense to add a TransformerClassifier in transformer.py, with a forward that looks something like:

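from typing import Optional

from torch import Tensor, nn
from torchtune.modules import TransformerDecoder


class TransformerClassifier(nn.Module):
    def __init__(self, transformer_decoder: TransformerDecoder, embed_dim: int, num_classes: int):
        super().__init__()
        # Assumes the decoder returns final hidden states [b x s x embed_dim], i.e. its
        # vocabulary output projection has been removed or bypassed.
        self.transformer_decoder = transformer_decoder
        self.score = nn.Linear(embed_dim, num_classes)

    def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        """
        Args:
            tokens (Tensor): input token IDs with shape [b x s]
            input_pos (Optional[Tensor]): optional positions of each token

        Returns:
            Tensor: class logits with shape [b x s x num_classes]; with num_classes=1,
                a sequence's reward is read from its last non-padding position
        """
        hidden_states = self.transformer_decoder(tokens, input_pos=input_pos)
        return self.score(hidden_states)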

Then, a corresponding component and model would be:

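# in _component_builders.py (where the mistral component builder is defined)
def mistral_classifier(embed_dim: int, num_classes: int, **mistral_args) -> TransformerClassifier:
    # Forward embed_dim so the decoder's hidden size matches the score head; the decoder's
    # vocabulary output projection must be removed or bypassed so the head receives hidden states
    transformer_decoder = mistral(embed_dim=embed_dim, **mistral_args)
    return TransformerClassifier(transformer_decoder, embed_dim, num_classes)

# in _model_builders.py
def mistral_7b_classifier() -> TransformerClassifier:
    ...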

Thank you again for your help and feedback. It's super interesting and fun to be contributing.

Sidenote

It probably wouldn't be much more effort to add support for training reward models once we implement reward models in torchtune directly. We could probably use the PreferenceDataset that was implemented for DPO. I suppose it's technically a form of fine-tuning, so it might be in scope for this library. It'd be really nice to allow users to go through the full RLHF process in native torch.
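Reward models trained on preference pairs (a chosen and a rejected completion per prompt, as in DPO-style preference data) commonly use a Bradley-Terry style pairwise loss. A minimal sketch, assuming one pooled scalar reward per completion; the function name is illustrative:

```python
import torch
import torch.nn.functional as F


def pairwise_reward_loss(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor) -> torch.Tensor:
    """Bradley-Terry style loss: -log sigmoid(r_chosen - r_rejected), averaged over the batch."""
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()


# Scalar rewards for two preference pairs, e.g. pooled outputs of a reward model.
chosen = torch.tensor([2.0, 0.5])
rejected = torch.tensor([0.0, 0.5])
loss = pairwise_reward_loss(chosen, rejected)
print(loss)
```

The loss only depends on the difference between the two rewards, so the model learns to rank the chosen completion above the rejected one rather than to predict absolute scores.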

kartikayk commented 2 months ago

This is awesome, @SalmanMohammadi! Seems like you have a lot of the components figured out!

A few comments from my side:

There might need to be some small modifications to convert_weights.py to support loading reward models

This is great! Currently we have logic in the checkpointer which does the state_dict conversion. My first thought would be that you can just create a branch here for reward models by using the model_type field. I'm not sure how general these might be, so maybe we can start with something like MISTRAL_REWARD_MODEL and extend it when we add more models? Let me know if that makes sense.
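A branch on `model_type` could look roughly like the pure-Python sketch below. `ModelType`, the mapping dictionaries and `convert_hf_state_dict` are hypothetical and heavily abbreviated (a real mapping covers every layer's weights), and the target key names are illustrative, not torchtune's actual ones.

```python
from enum import Enum


class ModelType(Enum):
    # Hypothetical subset; MISTRAL_REWARD_MODEL is the proposed new branch.
    MISTRAL = "mistral"
    MISTRAL_REWARD_MODEL = "mistral_reward"


# Hypothetical, abbreviated HF -> torchtune key mappings.
_BASE_FROM_HF = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.norm.weight": "norm.weight",
}
_REWARD_HEAD_FROM_HF = {"score.weight": "score.weight"}


def convert_hf_state_dict(state_dict: dict, model_type: ModelType) -> dict:
    mapping = dict(_BASE_FROM_HF)
    if model_type == ModelType.MISTRAL_REWARD_MODEL:
        # Reward models carry a classification ("score") head instead of an LM head.
        mapping.update(_REWARD_HEAD_FROM_HF)
    else:
        mapping["lm_head.weight"] = "output.weight"
    unknown = set(state_dict) - set(mapping)
    if unknown:
        raise KeyError(f"Unexpected keys for {model_type.name}: {sorted(unknown)}")
    return {mapping[key]: value for key, value in state_dict.items()}


hf_sd = {"model.embed_tokens.weight": "E", "model.norm.weight": "N", "score.weight": "S"}
tune_sd = convert_hf_state_dict(hf_sd, ModelType.MISTRAL_REWARD_MODEL)
print(tune_sd)
```

Keeping the reward-head keys in their own mapping means further reward models can be added by extending the enum and the branch, without touching the base conversion.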

Writing my thought process below

This is perfect! I'd do pretty much exactly this. There might be some small nuances which we catch once you have code, but this looks great to me. You alluded to this, but one thing which would be great to do is to verify correctness by running a random input through yours and some reference implementation and comparing the output tensors. This should give us confidence in the numerical equivalence and will help other folks use the module with high confidence. Let me know if this makes sense.
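The correctness check described here can be sketched as: load the same weights into both implementations (remapping parameter names, as a checkpointer would), run the same random input through each, and compare with `torch.testing.assert_close`. `OurClassifier` and the `nn.Sequential` "reference" are hypothetical toy stand-ins; in a real test the reference would be, for example, the Hugging Face model.

```python
import torch
from torch import nn


class OurClassifier(nn.Module):
    """Hypothetical re-implementation whose parameter names differ from the reference."""

    def __init__(self, dim: int = 8, num_classes: int = 1):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.score = nn.Linear(dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.score(torch.tanh(self.proj(x)))


torch.manual_seed(0)
reference = nn.Sequential(nn.Linear(8, 8), nn.Tanh(), nn.Linear(8, 1))
ours = OurClassifier()

# Copy the reference weights into our naming scheme.
key_map = {"0.weight": "proj.weight", "0.bias": "proj.bias", "2.weight": "score.weight", "2.bias": "score.bias"}
ours.load_state_dict({key_map[k]: v for k, v in reference.state_dict().items()})

x = torch.randn(4, 10, 8)  # random input [b, s, dim]
with torch.no_grad():
    # Raises an AssertionError with a mismatch report if the outputs diverge.
    torch.testing.assert_close(ours(x), reference(x), atol=1e-6, rtol=1e-5)
```

Tolerances are a judgment call: identical ops on CPU usually match almost exactly, while comparisons across precisions or kernels need looser bounds.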

It'd be really nice to allow users to go through the full RLHF process in native torch.

100% agreed on this. I'd love to collaborate on adding this if you'd be interested. My initial thought here is that this shouldn't be too complicated to add. What do you think?

I also saw you had another question about MLP implementations and why these are copied over :) I think it was a great question. Generally, we've tried to decouple the builders for different models as much as possible. This does lead to copy-pasting some code, but generally makes things easy to handle, maintain, extend and ultimately deprecate. If you try to squeeze too many things into a single implementation, it ultimately becomes bloated and full of conditionals, which makes any sort of extension or refactor hard. Over time, we may find opportunities to consolidate and merge code, but that's an easier operation than splitting things apart later to keep complexity down, since splitting will likely break tons of users. Hope this makes sense. Happy to answer more questions!

SalmanMohammadi commented 2 months ago

I love the support and positivity @kartikayk :)

I've put a PR up for a (hopefully) pretty lightweight and non-invasive TransformerClassifier implementation. I could use some guidance on numerical testing. I'd be happy to also add correctness tests for base mistral, and then the mistral classifier.

I'd love to collaborate on adding this if you'd be interested

I think it should be pretty straightforward to add a recipe for training this classifier on a reward modelling task! I'd be happy to hear your thoughts on anything that's missing. I mentioned in the PR that we could start with a recipe using the classifier and the dataset that was implemented for DPO.
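A reward-modelling recipe boils down to a training step like the sketch below: score chosen and rejected completions in one forward pass, read one reward per sequence, and apply the pairwise loss. `ToyRewardModel` and `reward_training_step` are hypothetical; for simplicity the sketch assumes unpadded chosen/rejected sequences of equal length and reads the reward at the final position.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToyRewardModel(nn.Module):
    """Hypothetical stand-in for a classifier with num_classes=1."""

    def __init__(self, vocab_size: int = 32, dim: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.score = nn.Linear(dim, 1)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.score(self.embed(tokens)).squeeze(-1)  # per-position scores [b, s]


def reward_training_step(model, optimizer, chosen, rejected) -> float:
    # Stack chosen and rejected so both are scored in a single forward pass.
    scores = model(torch.cat([chosen, rejected], dim=0))
    rewards = scores[:, -1]  # unpadded sequences: reward at the final position
    r_chosen, r_rejected = rewards.chunk(2)
    loss = -F.logsigmoid(r_chosen - r_rejected).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = ToyRewardModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
chosen = torch.randint(0, 32, (4, 6))
rejected = torch.randint(0, 32, (4, 6))
losses = [reward_training_step(model, optimizer, chosen, rejected) for _ in range(50)]
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

With padded batches, the final-position read would be replaced by pooling at each sequence's last non-padding token.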

Generally, we've tried to decouple the builders for different models as much as possible.

I ended up answering my own question after reading the codebase. It's great to hear your thoughts. There's always a little SWE inside of me complaining about code duplication :) I think the other advantage of the kind of low-coupling, high-modularity code you mentioned is interpretability. I could easily figure out where the implementation details were for an architecture I was interested in. This is imo a seriously underrated feature of an open-source, popular ML codebase. It makes a huge difference to users at every level of expertise, particularly users coming from a non-SWE background who want to understand how things work on a more technical level.

Next steps

It'd be good to talk more about implementing reward model training. Once we've worked through the TransformerClassifier testing and the PR looks good, I'll hopefully have most of the components I need to implement PPO too. I don't currently have resources to test or train larger models; if you have suggestions for cheap cloud compute, or compute for open-source development, I'd appreciate any pointers! On a more general note, I'd also be happy to help write tutorials/documentation on the things we're working on.

kartikayk commented 2 months ago

Awesome, love the PR @SalmanMohammadi! I'll review in a bit, but I see that you already have a discussion going!

I could easily figure out where the implementation details were for an architecture I was interested in.

You hit the nail on the head. This was exactly the intent, and I'm glad it resonates. It's one of the design principles we had a lot of discussion and debate about :)

I don't currently have resources to test or train larger models - if you have suggestions for cheap cloud compute/compute for open-source development I'd appreciate any pointers

I've been using RunPod for my own development and testing. Let me know if this works for you. Of course, we'd be happy to do some testing on larger models as well and share all of our learnings and observations with you.

This is really exciting! Thanks for helping shape this up. I'm looking forward to sharing this with the community :)

SalmanMohammadi commented 2 months ago

@kartikayk the TransformerClassifier PR is pretty much good to go. Would you still like to collaborate on the RLHF process? There are a lot of steps, and I have some design docs I could share on the different components we need. Happy to chat here or on Discord to share some of my draft ideas!

kartikayk commented 2 months ago

@SalmanMohammadi I'm still very interested in the actual training! We can create a sidebar on discord to chat about this so other interested folks can follow along as well. WDYT?

SalmanMohammadi commented 2 months ago

Sounds good! Let me know what you're interested in and I can share my thoughts/updates on what I'm working on. Let's chat more on Discord.

kartikayk commented 2 months ago

Sounds good! Mind sharing your discord handle? :)

SalmanMohammadi commented 2 months ago

It's v3rsachie