pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Input embeddings directly in `TransformerDecoder` #968

Open ghost opened 6 months ago

ghost commented 6 months ago

Currently, the `forward` method of the `TransformerDecoder` class requires a `tokens` tensor of shape `[b, s]` to be passed as an argument, which is then passed to `self.tok_embeddings`.
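
For context, a minimal sketch of that coupling (a simplification for illustration, not torchtune's actual implementation, which takes additional arguments):

```python
import torch
from torch import nn

class TransformerDecoder(nn.Module):
    """Simplified sketch: forward only accepts integer token IDs."""

    def __init__(self, vocab_size: int, embed_dim: int):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
        # ... attention layers, norms, and output projection elided ...

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: [b, s] integer IDs -- the only input shape accepted today
        h = self.tok_embeddings(tokens)  # [b, s, d]
        # ... rest of the decoder stack would run here ...
        return h
```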

But transformers can do far more than process text, and sometimes you want to feed them data that is more complex than sequences of integer token IDs.

Perhaps it would be worth relaxing the `TransformerDecoder` implementation to make it easier to use in such cases?

Specifically, the input could be allowed to have any shape `[b, s, ...]`, and the type of `tok_embeddings` could be relaxed from `nn.Embedding` to any `nn.Module` that returns a tensor of shape `[b, s, d]` (sketched below).
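
For illustration, a hedged sketch of what such a drop-in embedder could look like; `FeatureEmbedder` and its signature are hypothetical, not part of torchtune:

```python
import torch
from torch import nn

# Hypothetical embedder: any nn.Module mapping per-step features
# [b, s, ...] to embeddings [b, s, d] would qualify.
class FeatureEmbedder(nn.Module):
    def __init__(self, feature_dim: int, embed_dim: int):
        super().__init__()
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [b, s, feature_dim] float features -> [b, s, embed_dim]
        return self.proj(x)
```

Such a module could be swapped in wherever `nn.Embedding` is used today, since both produce a `[b, s, d]` tensor.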

Alternatively, do it like the Hugging Face `transformers` library, which allows `inputs_embeds` to be passed directly instead of `input_ids`.
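
A hedged sketch of what that convention could look like here; the `input_embeds` keyword and the validation logic are illustrative, not an existing torchtune API (Hugging Face spells its version `inputs_embeds`):

```python
from typing import Optional

import torch
from torch import nn

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
        # ... decoder layers elided ...

    def forward(
        self,
        tokens: Optional[torch.Tensor] = None,        # [b, s] integer token IDs
        input_embeds: Optional[torch.Tensor] = None,  # [b, s, d] precomputed embeddings
    ) -> torch.Tensor:
        # Exactly one of the two inputs must be provided.
        if (tokens is None) == (input_embeds is None):
            raise ValueError("pass exactly one of `tokens` or `input_embeds`")
        h = input_embeds if input_embeds is not None else self.tok_embeddings(tokens)
        # ... rest of the decoder stack would run here ...
        return h
```

A caller could then pass, e.g., `decoder(input_embeds=vision_features)` where `vision_features` is a `[b, s, d]` tensor produced by some other encoder.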

kartikayk commented 6 months ago

Thanks for opening this issue!

This is a good suggestion. We've been discussing redesigning the transformer decoder to output more information (e.g., hidden states from intermediate layers). I think making the embedding layer more generic can be part of that change. I'll put up an RFC some time next week and share it here.