pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

[RFC] TransformerDecoder refactor #1017

SalmanMohammadi closed this issue 4 weeks ago

SalmanMohammadi commented 3 months ago

TransformerDecoder Refactor

Authors:

Refactoring TransformerDecoder to offer additional flexibility for new use-cases.

Motivation/Prior art

Currently, TransformerDecoder can only be used for language-modelling tasks. There is interest in additional use-cases, such as:

Such a refactor could allow users to easily adapt a transformer backbone for a variety of downstream tasks; lm_human_preference_details demonstrates how HF's transformer backbone can be extended in just 8 lines. While this refactor initially targets recipes that will be provided within torchtune, such as PPO or sequence-classification training recipes (e.g. for reward models), it would also allow users to write custom recipes for many fine-tuning tasks while using underlying torchtune features.

Proposed Implementation

A small-scale implementation for Mistral models exists in this draft PR. In summary:

* mistral_classifier should now return an instance of TransformerClassifier.
* Gemma models define a GemmaTransformerDecoder, which has a unique output projection but shares the underlying logic of a TransformerDecoder. There are two routes we could take here:

1) TransformerLM accepts a Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]] (or even just Callable[[torch.Tensor], torch.Tensor]) as output. The Gemma component builder then becomes:

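tok_embeddings = nn.Embedding(...)
# the decoder owns the embedding table and returns hidden states, not logits
decoder = TransformerDecoder(tok_embeddings=tok_embeddings, ...)
# tied output projection: reuse the embedding matrix as the LM head
output = lambda h: F.linear(h, tok_embeddings.weight)
return TransformerLM(decoder, output=output)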

2) If that looks a bit clunky, or we'd rather avoid anonymous functions, we can just define a GemmaTransformerLM that looks like:

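from typing import Optional

import torch.nn.functional as F
from torch import Tensor, nn
from torchtune.modules import TransformerDecoder


class GemmaTransformerLM(nn.Module):
    def __init__(self, decoder: TransformerDecoder) -> None:
        super().__init__()
        self.decoder = decoder

    def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        # shape: [b, s, d]
        h = self.decoder(tokens, input_pos=input_pos)
        # tied output projection via the token embedding matrix -> [b, s, vocab_size]
        return F.linear(h, self.decoder.tok_embeddings.weight)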
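A runnable sketch of option 1, under stated assumptions: `TransformerLM` here is a hypothetical minimal wrapper that calls `decoder` and then `output`, and `ToyDecoder` is a toy stand-in for torchtune's `TransformerDecoder` that returns hidden states rather than logits. It illustrates that the lambda head reuses the embedding table instead of creating a second set of weights.

```python
from typing import Callable, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class ToyDecoder(nn.Module):
    """Stand-in for TransformerDecoder that returns hidden states [b, s, d]."""

    def __init__(self, tok_embeddings: nn.Embedding) -> None:
        super().__init__()
        self.tok_embeddings = tok_embeddings
        dim = tok_embeddings.embedding_dim
        self.proj = nn.Linear(dim, dim)

    def forward(self, tokens: Tensor) -> Tensor:
        return self.proj(self.tok_embeddings(tokens))


class TransformerLM(nn.Module):
    """Hypothetical wrapper: run the decoder, then apply the output head."""

    def __init__(
        self,
        decoder: nn.Module,
        output: Union[nn.Module, Callable[[Tensor], Tensor]],
    ) -> None:
        super().__init__()
        self.decoder = decoder
        # an nn.Module is registered as a submodule; a plain callable is just an attribute
        self.output = output

    def forward(self, tokens: Tensor) -> Tensor:
        return self.output(self.decoder(tokens))


vocab_size, dim = 32, 16
tok_embeddings = nn.Embedding(vocab_size, dim)
decoder = ToyDecoder(tok_embeddings)
# option 1: the lambda reuses the embedding matrix as the output projection
output = lambda h: F.linear(h, tok_embeddings.weight)
model = TransformerLM(decoder, output=output)

logits = model(torch.randint(0, vocab_size, (2, 5)))
print(logits.shape)  # torch.Size([2, 5, 32])
# no second copy of the weights: only the embedding and proj parameters exist
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 784
```

Option 2 computes the same logits; it simply gives the tied projection a named class instead of an anonymous function.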

Input on how this affects FSDP would be appreciated.

Components in the codebase I estimate will be impacted, and changes necessary, include:

A note on backwards compatibility

Users who have previously trained models with TransformerDecoder will have checkpoints saved with state dict keys in the original format (without the "decoder." prefix). Am I right in thinking they're going to have issues loading these checkpoints into our new models? This could be a pretty disruptive change, since some users will have spent a lot of resources fine-tuning their models.

Could we provide some well-documented deprecation support for converting state dicts until some release version?

pbontrager commented 3 months ago

Some initial comments/opinions after my first pass through the RFC

SalmanMohammadi commented 3 months ago

Thanks for your thoughts @pbontrager!

I don't know if we should treat model heads as wrappers around the decoder; they should be treated more like additional layers. Then you could build your LLM with nn.Sequential(TransformerDecoder(), TransformerLM())

I like how this looks - it makes sense since heads are just additional layers. Some contra points:

For the case of PPO, where there are multiple heads, you only want to do a forward pass on the decoder once, right? So you either need to make a special head that takes the output from the decoder, passes it to a list of heads, and returns all of their outputs, or you need to have the recipe manage the decoder and heads separately. I believe this creates issues for checkpointing, though.

Are you suggesting something like this?

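from typing import Dict

from torch import Tensor, nn


class MultiOutputHead(nn.Module):
    def __init__(self, layers: nn.ModuleDict) -> None:
        super().__init__()
        self.layers = layers

    def forward(self, decoder_output: Tensor) -> Dict[str, Tensor]:
        # every head sees the same decoder output; return dict or tuple or dataclass? (dict here)
        return {name: head(decoder_output) for name, head in self.layers.items()}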

My thoughts here are similar to above: typing is now abstract, and we're trading some modularity here. Heads can now be composed together arbitrarily, but it might be difficult to deal with different heads using different inputs. I don't actually have a use-case in mind, so perhaps it's not worth trying to over-engineer here. I'm keeping in mind we've also been talking about flexibility in the outputs from TransformerDecoder, so this could just mean we deal with unpacking tuples/dataclasses in each head and assume/check that the decoder has been setup to return necessary states.
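To make the "unpack in each head" idea concrete, here is a sketch assuming a hypothetical `DecoderOutput` dataclass (not an existing torchtune type) that holds the last hidden state and, optionally, every layer's hidden state. Each head takes the field it needs and fails loudly if the decoder wasn't set up to return it. `ValueHead` and `MidLayerProbe` are illustrative names only.

```python
from dataclasses import dataclass
from typing import List, Optional

import torch
from torch import Tensor, nn


@dataclass
class DecoderOutput:
    # hypothetical container for what a decoder could return
    last_hidden_state: Tensor  # [b, s, d]
    hidden_states: Optional[List[Tensor]] = None  # one [b, s, d] tensor per layer


class ValueHead(nn.Module):
    """Scalar value per token, computed from the final hidden state."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, 1)

    def forward(self, out: DecoderOutput) -> Tensor:
        return self.proj(out.last_hidden_state).squeeze(-1)  # [b, s]


class MidLayerProbe(nn.Module):
    """Example head that needs an intermediate layer's hidden state."""

    def __init__(self, dim: int, layer_idx: int, num_classes: int) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.proj = nn.Linear(dim, num_classes)

    def forward(self, out: DecoderOutput) -> Tensor:
        if out.hidden_states is None:
            raise ValueError("decoder must be set up to return intermediate hidden states")
        return self.proj(out.hidden_states[self.layer_idx])  # [b, s, num_classes]


b, s, d = 2, 4, 8
states = [torch.randn(b, s, d) for _ in range(3)]
out = DecoderOutput(last_hidden_state=states[-1], hidden_states=states)
values = ValueHead(d)(out)
probe_logits = MidLayerProbe(d, layer_idx=1, num_classes=3)(out)
```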

RE: Checkpointing: I currently handle checkpointing in a similar way to LoRA checkpointing. The base model with the language head gets saved by default, and the value head weights are stored and loaded separately during training.
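A rough sketch of that split, assuming plain `nn.Linear` stand-ins for the policy model and the value head and hypothetical file names: the base model is saved in its usual format, while the value head's weights go to a separate file and are loaded back independently, much like LoRA adapter weights.

```python
import os
import tempfile

import torch
from torch import nn

policy = nn.Linear(8, 8)      # stand-in for the base model + language head
value_head = nn.Linear(8, 1)  # extra head trained alongside it

ckpt_dir = tempfile.mkdtemp()
# the base model is checkpointed in its usual format...
torch.save(policy.state_dict(), os.path.join(ckpt_dir, "model.pt"))
# ...while the value head is written to its own file
torch.save(value_head.state_dict(), os.path.join(ckpt_dir, "value_head.pt"))

# when resuming, each piece is loaded independently
restored_head = nn.Linear(8, 1)
restored_head.load_state_dict(
    torch.load(os.path.join(ckpt_dir, "value_head.pt"), weights_only=True)
)
```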

For backward compatibility, we might have to add a function in the checkpoint that detects the old version and forces you to run a script that updates the checkpoint. We wouldn't have to maintain it forever since it would force everyone's checkpoints to be updated over a short period of time.

Sounds good to me!


Hope my points make sense, and that I haven't misinterpreted or missed something obvious.

ebsmothers commented 2 months ago

This is an awesome RFC! Overall I like the direction here, just chiming in on a few miscellaneous points raised by both you and @pbontrager.

TransformerDecoder could support returning hidden states from arbitrary layers (or other useful outputs). Some input on how we allow users to specify this would be helpful. We probably just want to return the last hidden state by default.

Is this part of the changes you're proposing here? Or would it be done separately? (If not strictly needed for PPO I'd say it's fine to save for a follow-up, but I know this was one of the things @kartikayk was looking at with this refactor.) If we do want to support it here, I feel like a bool return_hidden_states (or something like that) passed to forward is probably the way to go. Then we return a Union[List[Tensor], Tensor] or something like that. I don't wanna overgeneralize too much and I think this should cover the majority of cases.
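A minimal sketch of the `return_hidden_states` flag, with a toy stack of linear layers standing in for transformer layers. Only the flag name comes from the suggestion above; the class and its return convention are assumptions for illustration, not torchtune's implementation.

```python
from typing import List, Union

import torch
from torch import Tensor, nn


class ToyDecoder(nn.Module):
    def __init__(self, dim: int, num_layers: int) -> None:
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])

    def forward(
        self, h: Tensor, return_hidden_states: bool = False
    ) -> Union[List[Tensor], Tensor]:
        hidden_states = []
        for layer in self.layers:
            h = torch.relu(layer(h))
            hidden_states.append(h)
        # default: only the last hidden state; opt-in: every layer's output
        return hidden_states if return_hidden_states else h


decoder = ToyDecoder(dim=8, num_layers=3)
x = torch.randn(2, 4, 8)
last = decoder(x)                                   # Tensor of shape [2, 4, 8]
all_states = decoder(x, return_hidden_states=True)  # list of 3 tensors
```

The default call path is unchanged for existing callers, and the last element of the returned list equals the default return value.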

mistral_classifier should now return an instance of TransformerClassifier.

nitpicking but I assume you mean TransformerLM? (Just wanna clarify I'm not misunderstanding)

I think tests for TransformerDecoder should now test TransformerLM - input appreciated.

This makes sense to me.

I don't know if we should treat model heads as wrappers around the decoder; they should be treated more like additional layers. Then you could build your LLM with nn.Sequential(TransformerDecoder(), TransformerLM())

I don't like the nn.Sequential too much tbh. I think the multi-head example is fairly common and we may want to do something like MultiOutputHead (in fact we have such an example in multimodal). I agree checkpoint loading is a bit of a headache in this case but I claim it wouldn't be our job to maintain the mapping for arbitrary state dicts (that would be on the user). But even in that case it is roughly equivalent to how we load LoRA weights -- load some partial state dict with strict=False that needs to adhere to a particular format, then validate that everything lines up.
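A sketch of the LoRA-style pattern applied to a multi-head model, assuming a hypothetical `MultiHeadModel` whose user-defined heads live under a `heads.` prefix: load a backbone-only checkpoint with `strict=False`, then check that nothing unexpected was passed in and that the only missing keys are head weights.

```python
from torch import nn


class MultiHeadModel(nn.Module):
    """Hypothetical model: a shared backbone plus user-defined heads."""

    def __init__(self, dim: int = 8) -> None:
        super().__init__()
        self.decoder = nn.Linear(dim, dim)  # stand-in for the backbone
        self.heads = nn.ModuleDict(
            {"lm": nn.Linear(dim, 16), "value": nn.Linear(dim, 1)}
        )


model = MultiHeadModel()
# an "official" checkpoint that only covers the backbone
backbone = nn.Linear(8, 8)
backbone_sd = {f"decoder.{k}": v for k, v in backbone.state_dict().items()}

# load what we can; missing/unexpected keys are reported instead of raising
result = model.load_state_dict(backbone_sd, strict=False)

# validate that everything lines up
if result.unexpected_keys:
    raise RuntimeError(f"unexpected keys: {result.unexpected_keys}")
not_heads = [k for k in result.missing_keys if not k.startswith("heads.")]
if not_heads:
    raise RuntimeError(f"backbone keys missing from checkpoint: {not_heads}")
```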

For now my proposal for a compromise is to type output in TransformerLM as nn.Module instead of nn.Linear. That leaves the door open for other usage, but all our builders and official checkpoint mappings will use nn.Linear and anything beyond that will require the user to make some modifications (btw in general we do not really guarantee state dict key uniqueness anyways, different implementations of existing TransformerDecoder class may or may not have e.g. layers.0.attn.q_proj.lora_a.weight as a state dict key).
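Assuming `TransformerLM` simply composes a decoder with an output module, typing `output` as `nn.Module` lets official builders keep passing an `nn.Linear` while users plug in any other module. In this sketch an `nn.Embedding` stands in for the decoder and `LastTokenClassifier` is a hypothetical custom head; note that the state dict keys gain the `decoder.` prefix raised in the backwards-compatibility note.

```python
import torch
from torch import Tensor, nn


class TransformerLM(nn.Module):
    def __init__(self, decoder: nn.Module, output: nn.Module) -> None:
        super().__init__()
        self.decoder = decoder
        self.output = output  # nn.Linear in official builders, any module for custom use

    def forward(self, tokens: Tensor) -> Tensor:
        return self.output(self.decoder(tokens))


class LastTokenClassifier(nn.Module):
    """Custom head: classify each sequence from its final hidden state."""

    def __init__(self, dim: int, num_classes: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, num_classes)

    def forward(self, h: Tensor) -> Tensor:
        return self.proj(h[:, -1])  # [b, s, d] -> [b, num_classes]


vocab_size, dim = 32, 16
decoder = nn.Embedding(vocab_size, dim)  # toy stand-in returning [b, s, d] hidden states
tokens = torch.randint(0, vocab_size, (2, 5))

lm = TransformerLM(decoder, nn.Linear(dim, vocab_size, bias=False))
clf = TransformerLM(decoder, LastTokenClassifier(dim, num_classes=3))
print(lm(tokens).shape, clf(tokens).shape)  # torch.Size([2, 5, 32]) torch.Size([2, 3])
print(list(lm.state_dict()))  # ['decoder.weight', 'output.weight']
```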

I feel the name TransformerLM is misleading as it doesn't actually have a language model. I'd lean towards LanguageHead

Are you talking about naming of just the head or the full model + head? If the full model I think TransformerLM should be OK.

I like Gemma with the lambda function a lot, that's what I was wanting to do originally. I don't remember if it caused an issue with FSDP though.

Some people are morally opposed to the usage of lambda functions, but I am not one of them. I also think it should be equivalent from the perspective of FSDP (as long as we are not creating a separate set of weights and tying them together it should be fine).

For backward compatibility, we might have to add a function in the checkpoint that detects the old version and forces you to run a script that updates the checkpoint. We wouldn't have to maintain it forever since it would force everyone's checkpoints to be updated over a short period of time.

Another path is to just infer the version and remap during checkpoint load. So here we could add something like

if not any(key.startswith("decoder.") for key in state_dict[utils.MODEL_KEY]):
    # basically just prepend every non-output param key with "decoder."
    state_dict[utils.MODEL_KEY] = convert_v1_to_v2_state_dict(state_dict[utils.MODEL_KEY])

A bit more of a black box but imo the checkpointer is kind of that already anyways. Separately we should come up with a proper definition of checkpoint versioning along with support/deprecation model.
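A minimal sketch of the infer-and-remap approach. `convert_v1_to_v2_state_dict` is the hypothetical helper named above and `is_v1_state_dict` is an additional illustrative helper; the sketch assumes the output projection's keys start with `output.` and that every other key moves under the `decoder.` prefix.

```python
from typing import Dict

import torch
from torch import Tensor


def is_v1_state_dict(state_dict: Dict[str, Tensor]) -> bool:
    # v2 checkpoints nest backbone params under "decoder."; v1 checkpoints don't
    return not any(k.startswith("decoder.") for k in state_dict)


def convert_v1_to_v2_state_dict(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
    # prepend "decoder." to every key except those of the output projection
    return {
        k if k.startswith("output.") else f"decoder.{k}": v
        for k, v in state_dict.items()
    }


v1 = {
    "tok_embeddings.weight": torch.zeros(4, 2),
    "layers.0.attn.q_proj.weight": torch.zeros(2, 2),
    "output.weight": torch.zeros(4, 2),
}
v2 = convert_v1_to_v2_state_dict(v1) if is_v1_state_dict(v1) else v1
print(sorted(v2))
# ['decoder.layers.0.attn.q_proj.weight', 'decoder.tok_embeddings.weight', 'output.weight']
```

Because the check keys off the `decoder.` prefix, running it on an already-converted state dict is a no-op.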

SalmanMohammadi commented 2 months ago

Thanks so much for your comments @ebsmothers.

If not strictly needed for PPO I'd say it's fine to save for a follow-up, but I know this was one of the things @kartikayk was looking at with this refactor

I generally agree here. Your suggestion sounds straightforward, and if @kartikayk is keen to see it here while I'm making other changes, I'm happy to include it.

If I understand correctly, you're suggesting we have a base TransformerLM class that encapsulates a base TransformerDecoder and some arbitrary head. This head could be used for classification, for language modeling, or for multi-head outputs via a MultiOutputHead. Is that right?

I think this abstraction keeps things simple. I'd imagine the return type of TransformerLM.forward could be something like Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]].

nitpicking but I assume you mean TransformerLM? (Just wanna clarify I'm not misunderstanding)

The reason I raised this is that not every head results in a language model per se. Still, distinguishing between model types in the naming perhaps isn't worth the added complexity, unless we use something more general like TransformerModel.

But even in that case it is roughly equivalent to how we load LoRA weights -- load some partial state dict with strict=False that needs to adhere to a particular format, then validate that everything lines up.

This makes sense to me.

anything beyond that will require the user to make some modifications

I claim it wouldn't be our job to maintain the mapping for arbitrary state dicts (that would be on the user)

I think keeping things as simple and easy to understand as possible does more to make the codebase extensible for user-specific purposes than complex abstractions designed to minimise code changes, so I'd agree here.

if not any(key.startswith("decoder.") for key in state_dict[utils.MODEL_KEY]): convert_v1_to_v2_state_dict() # <-basically just prepend every non-output param key with "decoder."

I'd lean towards this solution as it feels slightly more user friendly.

ghost commented 2 months ago

Let me express my humble opinion on the subject, as my issue was mentioned in the initial post.

First of all, it seems to me that backward compatibility should not be a major concern here, since the project is at a very early stage of development and is not yet very popular. If there is a time to make backward-breaking changes, it is now.

Secondly, this library is being developed as part of the pytorch project, so it seems reasonable to expect it to follow the conventions of its older big brother first, rather than outside projects like huggingface transformers.

Based on the above, doesn't it seem that the TransformerDecoder module fits better in the main PyTorch library than here? If it were an LSTM, you wouldn't implement it from scratch here, but import it from there.

ebsmothers commented 2 months ago

@marcinwazny thanks for weighing in here. Re BC-breaking changes I completely agree, it's inevitable that they will occur at this stage in the project. At the same time we want to make sure that we do not leave the users we do have high and dry and without a clear path forward.

Re your suggestion to move TransformerDecoder to PyTorch core, well.. it already exists there 😃. But there are very different contracts in the core PyTorch repo from something like torchtune. For one, you can see from the commit history under torch/nn/modules that there are not a lot of new features going in there (and for a project of PyTorch's scope, this is a good thing imo). One of the primary goals of this project is to facilitate onboarding and experimentation with new state-of-the-art models and techniques easily and quickly. I claim the best way to do this is to actually keep higher-level modeling abstractions here where the barrier to entry is somewhat lower.

pbontrager commented 2 months ago

I think I'm on board with this plan! I just want to +1 again having an option for returning internal state. Also for the checkpoint BC problem, I'm fine with both solutions, the main advantage I see with forcing the user to upgrade their checkpoint is so we don't have to support the old checkpoints for very long.

SalmanMohammadi commented 2 months ago

Thanks @pbontrager! I'm happy with a simple solution like @ebsmothers suggested, using return_hidden_states to return all the hidden states and letting the user deal with whichever ones they like.

I see both sides RE: backwards compatibility. It'll likely only be one major release's worth of BC that we'd need to support anyway, so I'd vote for the simplest option with the easiest UX.

SalmanMohammadi commented 2 months ago

If all is well with the plan, I'll put up a PR for this soon. Thanks for all the feedback :)

SalmanMohammadi commented 4 weeks ago

Closing since this will be addressed by #1224, and multi-output heads are no longer needed.