Open ebsmothers opened 4 months ago
We've had this discussion before, so I'm sorry if I'm not remembering the response, but why can't we do `output=lambda h: F.linear(h, self.tok_embeddings.weight).float()` in the builder instead of adding an entire new class?
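For reference, a minimal self-contained sketch of that builder-level tying; `TinyDecoder` and `build_tiny_tied_decoder` are illustrative stand-ins rather than torchtune APIs, and note that in a builder there is no `self`, so the lambda has to close over the embedding module instead:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyDecoder(nn.Module):
    """Illustrative stand-in for a transformer decoder; layers are omitted."""
    def __init__(self, tok_embeddings: nn.Embedding, output):
        super().__init__()
        self.tok_embeddings = tok_embeddings
        self.output = output  # plain callable, so no extra parameter is registered

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.tok_embeddings(tokens)
        # ...attention/MLP layers would run here...
        return self.output(h)

def build_tiny_tied_decoder(vocab_size: int = 128, embed_dim: int = 16) -> TinyDecoder:
    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
    # Tie the output projection to the embedding table via a closure.
    output = lambda h: F.linear(h, tok_embeddings.weight).float()
    return TinyDecoder(tok_embeddings, output)

model = build_tiny_tied_decoder()
# Only the embedding weight shows up as a parameter: 128 * 16 = 2048.
print(sum(p.numel() for p in model.parameters()))
```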
@pbontrager I think technically speaking this should be doable too. Personally I kinda like having separate classes for this case though, since then we can be very clear on the contract for stuff like FSDP wrapping or checkpointing, where weight tying changes the contract in a meaningful way.
> where weight tying changes the contract in a meaningful way.

Can you expand on this point?
I just mean that we have an extra param floating around in the untied case that we do not have in the tied case. So e.g. our `memory_efficient_fsdp_wrap` config won't work for tied-weight models, and so it is nice to have an explicit class we can gate on there (though admittedly we're only really using this wrapping logic for Llama3 models right now anyways). Code pointer
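To make the gating idea concrete, here is a rough sketch of how a wrap policy could branch on the class. This is illustrative only, not the actual `memory_efficient_fsdp_wrap` logic, and the `torchtune.modules` import path for `TiedEmbeddingTransformerDecoder` is assumed:

```python
import torch.nn as nn
from torchtune.modules import TiedEmbeddingTransformerDecoder  # import path assumed

def fsdp_wrap_module_classes(model: nn.Module) -> set:
    """Illustrative only: pick module classes to wrap as separate FSDP units.

    In the tied case the output projection reuses tok_embeddings.weight, so
    there is no standalone output parameter to wrap on its own; gating on the
    class makes that explicit instead of sniffing for a missing parameter.
    """
    classes = {nn.Embedding}
    if not isinstance(model, TiedEmbeddingTransformerDecoder):
        # Untied models own a separate output projection worth wrapping.
        classes.add(nn.Linear)
    return classes
```

The point is just that the tied/untied distinction becomes a type check rather than an inspection of which parameters happen to exist.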
After the addition of Qwen2, we have multiple models using a transformer decoder class with head weights tied to embedding weights (Gemma and Qwen2). The `TiedEmbeddingTransformerDecoder` class is intended to handle these models and other similar such cases. But there are a couple follow-ups we need to address: (1) Refactor Gemma to use `TiedEmbeddingTransformerDecoder`, and (2) Add unit tests for `TiedEmbeddingTransformerDecoder`, as it is now in our core modules folder without testing (a possible shape for such a test is sketched below).
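For (2), one possible shape for the missing test, sketched under assumptions: that `qwen2_0_5b` is one of the tied-weight builders and that the `tok_embeddings` attribute and the absence of a standalone `output` parameter match the class as discussed above.

```python
import torch
from torchtune.models.qwen2 import qwen2_0_5b  # builder name assumed from the Qwen2 addition

def test_tied_embedding_decoder_ties_output_to_embeddings():
    model = qwen2_0_5b()
    tokens = torch.randint(0, 100, (2, 8))

    # Logits should cover the full vocabulary defined by the embedding table.
    logits = model(tokens)
    assert logits.shape == (2, 8, model.tok_embeddings.num_embeddings)

    # In the tied case there should be no standalone output parameter;
    # the projection reuses tok_embeddings.weight directly.
    assert not any("output" in name for name, _ in model.named_parameters())
```

A real test in the suite would likely build a much smaller decoder config rather than the full 0.5B model, but the assertions would be the same: logits span the vocab, and no separate output weight exists.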