Closed — kwen2501 closed this 1 month ago
I agree about the point that the input signature semantics are questionable. Can we improve that part?
`tensor` and leave it at that. Otoh, I don't see this as a bad/invasive change. I think we can better document it, but it has the net effect of making the code easier to work with for PP, with no other issues.
If we were to define a model from scratch, then yes, we could define it in whatever way fits us.
But if we were given a model and asked to serve it, it is questionable whether we should recommend this approach to the audience.
It would also fail to work once a layer's signature differs from the `Transformer` signature in the number of arguments it takes.
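To illustrate the point about argument counts (hypothetical names, not the actual torchtitan code): passing the input straight through stages only works while every layer accepts the same single argument; the moment a layer needs an extra input, the uniform calling convention breaks.

```python
import torch
from torch import nn

class BlockNeedingExtras(nn.Module):
    # Hypothetical layer whose signature differs from the Transformer's:
    # it needs rotary frequencies in addition to the hidden states.
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, h: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        return self.proj(h) + freqs_cis

h = torch.randn(2, 4, 16)
freqs = torch.zeros(2, 4, 16)
layer = BlockNeedingExtras()

out = layer(h, freqs)  # works when the caller knows about the extra argument
try:
    layer(h)           # a uniform one-argument convention fails here
except TypeError as e:
    print("signature mismatch:", e)
```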
I think of this as an opportunity to show off how it can be relatively simple to use the Manual PP frontend with a certain way of writing the model code.
Also, I think it is in line with the goals of torchtitan not to be too opinionated, but to show off the native PT-D tech. In this case, let's show off how both frontends can be applied to the model easily, and with a very consistent UX between them. (There are rough edges still, but once we fix them it'd be cool to show that the same model can be pretty easily fed into either frontend.)
https://github.com/pytorch/torchtitan/blob/f2c3a114e9737d8eca4078486b81f3529861e1b3/torchtitan/models/llama/model.py#L421-L433
The `if-else` in the code linked above may seem a bit invasive as a change to the original model, in that it changes the model's semantics.
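For context, a hedged sketch of the kind of `if-else` dispatch under discussion (the real code is at the permalink above; the class and the tiny dimensions here are illustrative only): the forward embeds `tokens` only when the embedding layer exists on this pipeline stage, and otherwise treats its input as already-embedded `h` values.

```python
import torch
from torch import nn

class StageAwareBlock(nn.Module):
    # Illustrative only: a forward that accepts either raw token indices
    # or already-embedded hidden states, depending on which pipeline
    # stage this module instance represents.
    def __init__(self, vocab_size: int = 128, dim: int = 16, has_embedding: bool = True):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim) if has_embedding else None
        self.layer = nn.Linear(dim, dim)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        if self.tok_embeddings is not None:
            h = self.tok_embeddings(inp)  # inp is int64 token indices
        else:
            h = inp                       # inp is already float hidden states
        return self.layer(h)

# First stage consumes tokens; a later stage consumes hidden states.
first = StageAwareBlock(has_embedding=True)
later = StageAwareBlock(has_embedding=False)
tokens = torch.randint(0, 128, (2, 4))
hidden = first(tokens)
out = later(hidden)
```

This is exactly the semantic overloading of the input argument that the thread debates: the same parameter name carries two meanings depending on stage placement.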
`tokens` are "Input token indices" as per the signature description. They should go through an embedding (`self.tok_embeddings`) to be expanded into feature values -- a space expansion -- and then get processed by the Transformer layers in that feature space. Passing `tokens` directly to transformer layers is semantically unclear. `tokens` are generally `int64`, so the `Transformer` module's `forward` may be typed to take a `torch.LongTensor`, whereas the `h` values are usually `bfloat16` or `float32`.
Dtype-annotated signatures are not uncommon; see for example the GPT2 model in Transformers. (In fact, it seems most models use detailed types.) https://github.com/huggingface/transformers/blob/481a95781404e48b1c80940be17e8279dec82fe8/src/transformers/models/gpt2/modeling_gpt2.py#L975-L990
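As a hedged sketch of the above (not the actual torchtitan llama model; the layer structure and tiny dimensions are illustrative), a dtype-annotated signature along these lines could look like:

```python
import torch
from torch import nn

class Transformer(nn.Module):
    # Simplified, hypothetical llama-style model for illustration only.
    def __init__(self, vocab_size: int = 128, dim: int = 16, n_layers: int = 2):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(dim)
        self.output = nn.Linear(dim, vocab_size)

    def forward(self, tokens: torch.LongTensor) -> torch.Tensor:
        # tokens: int64 indices, embedded into a float feature space
        h: torch.Tensor = self.tok_embeddings(tokens)  # bfloat16/float32
        for layer in self.layers:
            h = layer(h)
        return self.output(self.norm(h))
```

The annotation makes the int-in/float-out contract explicit, which is the point of contention: an overloaded input argument cannot be typed this precisely.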