ghost opened 6 months ago
Currently, the `forward` method of the `TransformerDecoder` class requires a `tokens` tensor of shape `[b, s]` to be passed as an argument, which is then passed to `self.tok_embeddings`. But the capabilities of transformers go far beyond working with text, and sometimes you want to use them with data that is more complex than sequences of integers.
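For reference, here is a minimal sketch of the current contract; the class body and names are simplified assumptions for illustration, not the actual implementation:

```python
import torch
import torch.nn as nn

class DecoderSketch(nn.Module):
    """Simplified stand-in for the current TransformerDecoder contract
    (illustration only, not the library's actual code)."""

    def __init__(self, vocab_size: int, embed_dim: int) -> None:
        super().__init__()
        # The embedding layer is fixed to nn.Embedding, so forward()
        # can only consume integer token IDs.
        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: integer tensor of shape [b, s]
        h = self.tok_embeddings(tokens)  # -> [b, s, d]
        # ... attention layers, final norm, output projection ...
        return h
```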
Perhaps it would be worth relaxing the `TransformerDecoder` implementation to make it easier to use in such cases? Specifically, allow the input data to have any shape `[b, s, ...]`, and change the type of `tok_embeddings` from `nn.Embedding` to any model that inherits from `nn.Module` and returns a tensor of shape `[b, s, d]`.
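Under that relaxed contract, a user-defined embedder could, for example, project continuous features into the model dimension. The `ContinuousEmbedder` class below is hypothetical, purely to illustrate the idea:

```python
import torch
import torch.nn as nn

class ContinuousEmbedder(nn.Module):
    """Hypothetical drop-in for tok_embeddings: accepts float features
    of shape [b, s, f] and returns embeddings of shape [b, s, d]."""

    def __init__(self, feat_dim: int, embed_dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(feat_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)  # [b, s, f] -> [b, s, d]

# With the relaxed type, this would slot in wherever nn.Embedding is
# accepted today, e.g. decoder.tok_embeddings = ContinuousEmbedder(16, 512)
```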
Alternatively, do it like the Hugging Face library, which allows `inputs_embeds` to be passed directly instead of `input_ids`.
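In that style, `forward` would accept either token IDs or precomputed embeddings, along the lines of the sketch below (a rough shape, not a proposed final signature):

```python
from typing import Optional
import torch

def forward(
    self,
    tokens: Optional[torch.Tensor] = None,         # [b, s] integer IDs
    inputs_embeds: Optional[torch.Tensor] = None,  # [b, s, d] embeddings
) -> torch.Tensor:
    # Mirror Hugging Face's input_ids / inputs_embeds convention:
    # exactly one of the two inputs must be provided.
    if (tokens is None) == (inputs_embeds is None):
        raise ValueError("Pass exactly one of `tokens` or `inputs_embeds`.")
    h = self.tok_embeddings(tokens) if inputs_embeds is None else inputs_embeds
    # ... rest of the decoder forward pass ...
    return h
```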
Thanks for opening this issue!

This is a good suggestion. We've been discussing redesigning the transformer decoder task to output more information (e.g. hidden states from intermediate layers). I think making the embedding layer more generic can be part of this change. I'll put up an RFC some time next week and share it here.