pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

TiedEmbeddingTransformerDecoder: unit test and Gemma refactor #1241

Open ebsmothers opened 4 months ago

ebsmothers commented 4 months ago

After the addition of Qwen2, we now have multiple models (Gemma and Qwen2) using a transformer decoder class whose output head weights are tied to the embedding weights. The class TiedEmbeddingTransformerDecoder is intended to handle these models and other similar cases. But there are a couple of follow-ups we need to address:

1. Refactor Gemma to use TiedEmbeddingTransformerDecoder.
2. Add unit tests for TiedEmbeddingTransformerDecoder, as it is now in our core modules folder without testing.
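For context, a minimal sketch (not the torchtune implementation) of the weight-tying pattern TiedEmbeddingTransformerDecoder is meant to capture: the final projection to vocabulary logits reuses the token embedding weight instead of owning a separate nn.Linear. All class and parameter names below are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyTiedDecoder(nn.Module):
    """Toy decoder with output head tied to the token embedding weight."""

    def __init__(self, vocab_size: int, embed_dim: int, num_layers: int = 2):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(embed_dim, nhead=4, batch_first=True)
            for _ in range(num_layers)
        )
        self.norm = nn.LayerNorm(embed_dim)
        # Note: no self.output = nn.Linear(...) here. Logits are computed
        # against the embedding matrix, so there is one fewer parameter tensor.

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Causal masking omitted for brevity; this only shows the tying.
        h = self.tok_embeddings(tokens)
        for layer in self.layers:
            h = layer(h)
        h = self.norm(h)
        return F.linear(h, self.tok_embeddings.weight)


if __name__ == "__main__":
    model = TinyTiedDecoder(vocab_size=128, embed_dim=32)
    logits = model(torch.randint(0, 128, (2, 8)))
    print(logits.shape)  # torch.Size([2, 8, 128])
```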

pbontrager commented 4 months ago

We've had this discussion before, so I'm sorry if I'm not remembering the response, but why can't we do `output=lambda h: F.linear(h, self.tok_embeddings.weight).float()` in the builder instead of adding an entire new class?
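A rough sketch of the builder-level alternative being suggested here: build the embedding first, then pass a callable that reuses its weight as the output projection, so no dedicated tied-embedding class is needed. The decoder class, builder name, and the assumption that the generic decoder accepts an arbitrary callable as its output projection are all illustrative, not the exact torchtune API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyDecoder(nn.Module):
    """Stand-in for a generic decoder that takes its output projection as an
    argument (the real torchtune TransformerDecoder has many more pieces)."""

    def __init__(self, tok_embeddings: nn.Embedding, output):
        super().__init__()
        self.tok_embeddings = tok_embeddings
        self.output = output  # any callable mapping hidden states to logits

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.tok_embeddings(tokens)
        return self.output(h)


def tiny_tied_builder(vocab_size: int = 128, embed_dim: int = 32) -> TinyDecoder:
    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
    # The lambda closes over the embedding module, so the output "layer" owns
    # no parameters of its own; logits come from the embedding matrix.
    return TinyDecoder(
        tok_embeddings=tok_embeddings,
        output=lambda h: F.linear(h, tok_embeddings.weight),
    )


model = tiny_tied_builder()
print(model(torch.randint(0, 128, (2, 8))).shape)  # torch.Size([2, 8, 128])
```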

ebsmothers commented 4 months ago

@pbontrager I think technically speaking this should be doable too. Personally I kinda like having separate classes for this case though, since then we can be very clear on the contract for stuff like FSDP wrapping or checkpointing, where weight tying changes the contract in a meaningful way.
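A small illustration (with hypothetical module names) of how that contract differs: the untied model carries an extra output.weight parameter that simply does not exist in the tied one, so state-dict keys and parameter groupings differ between the two cases.

```python
import torch.nn as nn


class UntiedHead(nn.Module):
    def __init__(self, vocab_size: int = 128, dim: int = 32):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)


class TiedHead(nn.Module):
    def __init__(self, vocab_size: int = 128, dim: int = 32):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        # No separate output projection; forward would reuse tok_embeddings.weight.


print(sorted(k for k, _ in UntiedHead().named_parameters()))
# ['output.weight', 'tok_embeddings.weight']
print(sorted(k for k, _ in TiedHead().named_parameters()))
# ['tok_embeddings.weight']
```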

joecummings commented 4 months ago

> where weight tying changes the contract in a meaningful way.

Can you expand on this point?

ebsmothers commented 4 months ago

> > where weight tying changes the contract in a meaningful way.
>
> Can you expand on this point?

I just mean that we have an extra param floating around in the untied case that we do not have in the tied case. So e.g. our `memory_efficient_fsdp_wrap` config won't work for tied-weight models, and it's nice to have an explicit class we can gate on there (though admittedly we're only really using this wrapping logic for Llama3 models right now anyways). Code pointer
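A hypothetical sketch of the kind of gating described above: choose an FSDP auto-wrap policy based on whether the model is a tied-embedding decoder, since a policy that assumes a separate output projection would not apply to tied-weight models. The helper name and the specific policy choices are illustrative, not torchtune's actual wrapping logic.

```python
import torch.nn as nn
from torch.distributed.fsdp.wrap import ModuleWrapPolicy


def choose_wrap_policy(model: nn.Module) -> ModuleWrapPolicy:
    # Gate on the model class rather than inspecting parameter names.
    if type(model).__name__ == "TiedEmbeddingTransformerDecoder":
        # Tied weights: wrap only the transformer layers; the shared
        # embedding/output weight stays with the root module.
        return ModuleWrapPolicy({nn.TransformerEncoderLayer})
    # Untied weights: the embedding and output projection can be wrapped
    # as separate units in addition to the layers.
    return ModuleWrapPolicy({nn.TransformerEncoderLayer, nn.Embedding, nn.Linear})
```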