@SalmanMohammadi thanks so much for the high quality RFC. PPO would be an amazing technique to add to torchtune!
Overall the plan looks good. A quick comment on the model itself:
Integration of reward models into the codebase
The description here maps very well onto the concepts we have in torchtune. We have component_builders, which stitch together the modules to build the overall architecture. For example, the Llama3 component builders can be found here; these include the llama3 and lora_llama3 builders. The llama3 component builder is used to create the llama3_8b and llama3_70b model builders here. Based on what you described, I'd imagine that you would build a PPO model by adding a custom component builder which keeps most of the architecture the same but replaces the output layer with what you have in mind. Does this generally make sense? Happy to answer more questions on this.
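To make the layering concrete, here's a toy version of the pattern (simplified illustration only, not the real torchtune signatures):

import torch.nn as nn

# Toy illustration only -- the real torchtune builders take many more arguments
# and return a TransformerDecoder rather than an nn.Sequential.

# Component builder: stitches modules together into an architecture.
def toy_decoder(vocab_size: int, embed_dim: int, num_layers: int) -> nn.Module:
    layers = [nn.Embedding(vocab_size, embed_dim)]
    layers += [nn.Linear(embed_dim, embed_dim) for _ in range(num_layers)]
    layers += [nn.Linear(embed_dim, vocab_size)]  # output projection over the vocab
    return nn.Sequential(*layers)

# Model builder: pins the hyperparameters for one published model size.
def toy_decoder_8b() -> nn.Module:
    return toy_decoder(vocab_size=32_000, embed_dim=4096, num_layers=32)

# A PPO/reward model would reuse the component builder but swap the final
# output projection for the head you have in mind.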
I'd need some more details on the implementation since there's a lot going on here, but I think these would be best communicated in the form of a prototype that does what you had in mind.
I'm also cc-ing @vmoens who's the RL expert in PyTorch for his thoughts and feedback!
Thanks so much for your feedback @kartikayk.
I think it makes sense to start with the reward model implementation. There's a pre-trained reward model for Mistral-7B, so implementing component and model builders for Mistral first would allow for easy testing. There might need to be some small modifications to convert_weights.py to support loading reward models.
In HuggingFace, reward models use the AutoModelForSequenceClassification generic. This is just a sequence model with a linear classification layer (example for Mistral-7B) slapped on top of the final hidden state from the underlying decoder model.
Writing my thought process below, I wonder if it makes sense to add a TransformerClassifier in transformer.py, with a forward that looks something like:
from typing import Optional

from torch import nn, Tensor

from torchtune.modules import TransformerDecoder


class TransformerClassifier(nn.Module):
    def __init__(self, transformer_decoder: TransformerDecoder, embed_dim: int, n_classes: int):
        super().__init__()
        self.transformer_decoder = transformer_decoder
        # classification head applied to the decoder output
        self.score = nn.Linear(embed_dim, n_classes)

    def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        """
        Args:
            tokens (Tensor): input token IDs with shape [b x s]
            input_pos (Optional[Tensor]): optional position IDs, as in TransformerDecoder
        Returns:
            Tensor: Preferences/rewards with shape [b x 1] (for n_classes = 1)
        """
        # TransformerDecoder output has shape [b x s x v]
        transformer_output = self.transformer_decoder(tokens, input_pos=input_pos)
        ...
        score = self.score(transformer_output)
        # return logits / apply act / etc.
        return score
Then, the corresponding component and model builders would be:
# in _component_builders.py
def mistral_classifier(embed_dim: int, n_classes: int, **mistral_args) -> TransformerClassifier:
    transformer_decoder = mistral(**mistral_args)
    return TransformerClassifier(transformer_decoder, embed_dim, n_classes)


# in _model_builders.py
def mistral_7b_classifier() -> TransformerClassifier:
    ...
Thank you again for your help and feedback. It's super interesting and fun to be contributing.
It probably wouldn't be much more effort to add support for training reward models once we implement reward models in torchtune directly. We could probably use the PreferenceDataset that was implemented for DPO. I suppose it's technically a form of fine-tuning, so it might be in scope for this library. It'd be really nice to allow users to go through the full RLHF process in native torch.
This is awesome, @SalmanMohammadi! Seems like you have a lot of the components figured out!
A few comments from my side:
There might need to be some small modifications to convert_weights.py to support loading reward models
This is great! Currently we have logic in the checkpointer which does the state_dict conversion. My first thought would be that you can just create a branch here for reward models by using the model_type field. I'm not sure how general these might be, so maybe we can start with something like MISTRAL_REWARD_MODEL and extend it when we add more models? Let me know if that makes sense.
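As a very rough sketch of what I mean (the new enum member and the key mapping are assumptions for illustration, not the actual checkpointer code):

from enum import Enum

# Hypothetical sketch only -- names and mappings are illustrative assumptions.

class ModelType(Enum):
    MISTRAL = "mistral"
    MISTRAL_REWARD_MODEL = "mistral_reward_model"  # proposed new member


def convert_hf_state_dict(hf_sd: dict, model_type: ModelType) -> dict:
    """Map HF parameter names to torchtune names, branching on model_type."""
    converted = {}
    for key, value in hf_sd.items():
        if model_type is ModelType.MISTRAL_REWARD_MODEL and key.startswith("score."):
            # the HF sequence-classification head ("score") maps onto the
            # classifier's output layer; everything else follows the usual
            # Mistral mapping
            converted[key.replace("score.", "output.", 1)] = value
        else:
            converted[key] = value  # placeholder for the existing Mistral mapping
    return converted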
Writing my thought process below
This is perfect! I'd do pretty much exactly this. There might be some small nuances which we catch once you have code, but this looks great to me. You alluded to this, but one thing which would be great to do is to verify correctness by running a random input through both your implementation and a reference implementation and comparing the output tensors. This should give us confidence in the numerical equivalence and will help other folks use the module with high confidence. Let me know if this makes sense.
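Concretely, I'm imagining something along these lines (the checkpoint ID is a placeholder, and mistral_7b_classifier / convert_hf_state_dict / ModelType refer to the sketches above):

import torch
from transformers import AutoModelForSequenceClassification

# Placeholder checkpoint; swap in the actual Mistral-7B reward model.
hf_model = AutoModelForSequenceClassification.from_pretrained(
    "some-org/mistral-7b-reward-model", num_labels=1
).eval()

# Build the torchtune classifier and load the converted weights
# (both from the sketches above).
tune_model = mistral_7b_classifier().eval()
tune_model.load_state_dict(
    convert_hf_state_dict(hf_model.state_dict(), ModelType.MISTRAL_REWARD_MODEL)
)

torch.manual_seed(0)
tokens = torch.randint(0, 32_000, (1, 128))  # random token IDs, Mistral vocab size

with torch.no_grad():
    hf_out = hf_model(tokens).logits   # [1, 1]
    tune_out = tune_model(tokens)      # [1, 1]

# tight tolerances give confidence in numerical equivalence
torch.testing.assert_close(tune_out, hf_out, rtol=1e-5, atol=1e-5)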
It'd be really nice to allow users to go through the full RLHF process in native torch.
100% agreed on this. I'd love to collaborate on adding this if you'd be interested. My initial thought here is that this shouldn't be too complicated to add. What do you think?
I also saw you had another question about MLP implementations and why these are copied over :) I think it was a great question. Generally, we've tried to decouple the builders for different models as much as possible. This does lead to copy-pasting some code, but it generally makes things easy to handle, maintain, extend, and ultimately deprecate. If you try to squeeze too many things into a single implementation, it ultimately becomes bloated and full of conditionals, which makes any sort of extension or refactor hard. Over time we may find opportunities to consolidate and merge code - but that's an easier operation than splitting things apart later to rein in complexity, since splitting would likely break tons of users. Hope this makes sense. Happy to answer more questions!
I love the support and positivity @kartikayk :)
I've put a PR up for a (hopefully) pretty lightweight and non-invasive TransformerClassifier implementation. I could use some guidance on numerical testing. I'd be happy to also add correctness tests for base mistral, and then the mistral classifier.
I'd love to collaborate on adding this if you'd be interested
I think it should be pretty straightforward to add a recipe for training this classifier on a reward modelling task! I'd be happy to hear your thoughts on anything that's missing. I mentioned in the PR that we could start with a recipe using the classifier and the dataset that was implemented for DPO.
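For reference, the core of such a recipe would be the standard pairwise (Bradley-Terry) reward-modelling loss over the chosen/rejected pairs that PreferenceDataset yields - roughly:

import torch
import torch.nn.functional as F


def reward_model_loss(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor) -> torch.Tensor:
    """Pairwise reward-modelling loss: push the chosen score above the rejected one.

    Both inputs have shape [b x 1]: the classifier's scalar outputs for the
    preferred and dispreferred completions of each prompt.
    """
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()


# in the training loop (sketch):
#   chosen_rewards = model(chosen_tokens)
#   rejected_rewards = model(rejected_tokens)
#   loss = reward_model_loss(chosen_rewards, rejected_rewards)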
Generally, we've tried to decouple the builders for different models as much as possible.
I ended up answering my own question after reading the codebase. It's great to hear your thoughts. There's always a little SWE inside of me complaining about code duplication :) I think the other advantage of the kind of low-coupling, high-modularity code you mentioned is interpretability: I could easily figure out where the implementation details were for an architecture I was interested in. This is, imo, a seriously underrated feature of a popular open-source ML codebase. It makes a huge difference to users at every level of expertise, and particularly to users coming from a non-SWE background who want to understand how things work at a more technical level.
Next steps
It'd be good to talk more about implementing reward model training. Once we've worked through the TransformerClassifier testing and the PR looks good, I'll hopefully have most of the components I need to implement PPO too. I don't currently have resources to test or train larger models - if you have suggestions for cheap cloud compute / compute for open-source development, I'd appreciate any pointers!
On a more general note, I'd also be happy to help write tutorials/documentation on the things we're working on.
Awesome, love the PR @SalmanMohammadi! I'll review in a bit, but see that you already have a discussion going!
I could easily figure out where the implementation details were for an architecture I was interested in.
You hit the nail on the head. This was exactly the intent, and I'm glad it resonates. It's one of the design principles we had a lot of discussion and debate about :)
I don't currently have resources to test or train larger models - if you have suggestions for cheap cloud compute / compute for open-source development, I'd appreciate any pointers
I've been using runpod for my own development and testing. Let me know if this works for you? Of course, we'd be happy to do some testing on larger models as well and share all of the learnings and observations with you.
This is really exciting! Thanks for helping shape this up. I'm looking forward to sharing this with the community :)
@kartikayk the TransformerClassifier PR is pretty much good to go. Would you still like to collaborate on the RLHF process? There's a lot of steps and I have some design docs I could share on the different components we need. Happy to chat here or on Discord to share some of my draft ideas!
@SalmanMohammadi I'm still very interested in the actual training! We can create a sidebar on discord to chat about this so other interested folks can follow along as well. WDYT?
Sounds good! Let me know what you're interested in and I can share my thoughts/updates on what I'm working on. Let's chat more on Discord.
Sounds good! Mind sharing your discord handle? :)
Implementing Proximal Policy Optimisation
I've used some of the PyTorch RFC template here for clarity.
Authors:
Summary
I'd like to add support for fine-tuning models using the Proximal Policy Optimisation (PPO) reinforcement learning (RL) algorithm. Like Direct Preference Optimisation (DPO), PPO is a core component of Reinforcement Learning from Human Feedback (RLHF) for aligning language models.
PPO optimises a language model which acts as a policy: the action space is the model's vocabulary, the observation space is the distribution over all possible prompts, and the reward is a scalar value indicating the "preference" for the model's completion of a given prompt (the reward is usually produced by a reward model calibrated to human preferences).
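For concreteness, the clipped surrogate objective at the core of PPO looks roughly like this (a minimal per-token sketch; a full recipe also needs a value loss, a KL penalty against a reference model, and advantage estimation such as GAE):

import torch


def ppo_clip_loss(
    logprobs: torch.Tensor,      # log-probs of taken tokens under the current policy
    old_logprobs: torch.Tensor,  # log-probs under the policy that generated the rollout
    advantages: torch.Tensor,    # advantage estimates, e.g. from GAE
    clip_eps: float = 0.2,
) -> torch.Tensor:
    """PPO clipped surrogate objective (negated, so it can be minimised)."""
    ratio = torch.exp(logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # take the pessimistic (element-wise minimum) surrogate and negate to minimise
    return -torch.min(unclipped, clipped).mean()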
Motivation
This repository helps make a fascinating technology even more accessible. Supporting PPO will help users to understand and explore LLM alignment techniques in native PyTorch, which is already widely adopted and easy to get started with.
Proposed Implementation
A new recipe under recipes/.
Prior art
TRL implements a generalised PPO trainer. A policy is defined as a thin wrapper around a pre-trained LLM which adds a value-function head to be optimised during PPO training. A copy of the model being trained is also initialised and frozen to act as a reference model.
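A minimal sketch of those two pieces (generic nn.Module code to show the idea, not TRL's actual classes):

import copy

import torch
import torch.nn as nn


class PolicyWithValueHead(nn.Module):
    """Thin wrapper adding a scalar value head to a pre-trained decoder."""

    def __init__(self, decoder: nn.Module, embed_dim: int):
        super().__init__()
        self.decoder = decoder
        self.value_head = nn.Linear(embed_dim, 1)

    def forward(self, tokens: torch.Tensor):
        hidden = self.decoder(tokens)  # assumes the decoder returns hidden states
        return hidden, self.value_head(hidden).squeeze(-1)


def make_frozen_reference(policy: nn.Module) -> nn.Module:
    """Deep-copy the policy and freeze it for use as the reference model."""
    ref = copy.deepcopy(policy).eval()
    for p in ref.parameters():
        p.requires_grad_(False)
    return ref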
Feedback and thoughts are very much appreciated. I'm hoping to add value here and I'm grateful for any guidance to help me do so.