facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

Position embedding does not support one-hot input #4460

Open KevinZhoutianyi opened 2 years ago

KevinZhoutianyi commented 2 years ago

❓ Questions and Help

Hi all,

I am using the pre-trained transformer for my project. I followed 'reproduce ende-wmt14' to train the transformer and want to improve that pre-trained model further. For some reason, I need to use one-hot encoded sentences as input (I cannot use argmax to convert those one-hot vectors back to indices).

However, it seems that the positional embedding in the transformer expects input of size [bsz x seqlen], but my input is [bsz x seqlen x vocabsize].

In this case, should I modify the basic transformer model, register it as a new one, and train on WMT14 again? Is there any tutorial on modifying an existing model architecture and registering it as a new one, or is there a convenient way to change the positional embedding?

Any help would be appreciated!

gmryu commented 2 years ago

The values in the (bsz x seqlen) input are exact indices into the vocabulary. To turn them into one-hot vectors, first create a (bsz x seqlen x vocabsize) tensor of zeros, then set vectors[b, t, index[b, t]] = 1 for every batch entry b and time step t.
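A minimal PyTorch sketch of that conversion (the toy batch and vocabulary size are made up for illustration): `scatter_` writes the 1s exactly as described, and `torch.nn.functional.one_hot` is the built-in equivalent.

```python
import torch
import torch.nn.functional as F

def indices_to_one_hot(tokens: torch.Tensor, vocab_size: int) -> torch.Tensor:
    """Turn a (bsz x seqlen) LongTensor of vocab indices into a
    (bsz x seqlen x vocab_size) float one-hot tensor."""
    # Start from zeros on the same device as the input...
    vectors = torch.zeros(*tokens.shape, vocab_size, device=tokens.device)
    # ...then write a 1 at vectors[b, t, tokens[b, t]].
    vectors.scatter_(-1, tokens.unsqueeze(-1), 1.0)
    return vectors

tokens = torch.tensor([[4, 7, 2], [5, 2, 1]])  # toy batch: bsz=2, seqlen=3
one_hot = indices_to_one_hot(tokens, vocab_size=10)
print(one_hot.shape)  # torch.Size([2, 3, 10])

# PyTorch's built-in helper gives the same values (it returns int64).
assert torch.equal(one_hot, F.one_hot(tokens, num_classes=10).float())
```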

I am not well-informed about using tensors efficiently. However, I believe actual (bsz x seqlen x vocabsize) one-hot vectors waste too much GPU memory, and knowing the indices alone should be sufficient for your further implementation.
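To put rough numbers on the memory concern, here is a back-of-the-envelope sketch. The sizes (bsz=32, seqlen=128, a 32,000-entry vocabulary) are hypothetical and do not come from the thread; it assumes int64 indices and float32 one-hot vectors.

```python
def tensor_bytes(num_elements: int, bytes_per_element: int) -> int:
    """Memory for one dense tensor, ignoring allocator overhead."""
    return num_elements * bytes_per_element

# Hypothetical sizes for illustration only.
bsz, seqlen, vocab = 32, 128, 32_000

index_bytes = tensor_bytes(bsz * seqlen, 8)            # int64 indices
one_hot_bytes = tensor_bytes(bsz * seqlen * vocab, 4)  # float32 one-hot

print(f"indices: {index_bytes / 2**10:.0f} KiB")    # indices: 32 KiB
print(f"one-hot: {one_hot_bytes / 2**20:.0f} MiB")  # one-hot: 500 MiB
```

That is a single copy of the input; training may keep further tensors of the same size for the backward pass. In an end-to-end setup, though, the first model's output distribution already has this (bsz x seqlen x vocabsize) shape.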

Good luck with your progress.

KevinZhoutianyi commented 2 years ago

Thanks for your reply, gmryu! I am working on an end-to-end back-translation project, so I need to feed the output of the first model into the second model. With some tricks, I convert the output of the first model into one-hot vectors and want to use them as the input of the second model. Because the training is end-to-end, I cannot use argmax to convert the one-hot vectors to indices, so I'm wondering whether it's possible to pass one-hot vectors as the input of a fairseq model.
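One observation that may help here (a standalone sketch, not fairseq code): a token-embedding lookup is just a row selection, so multiplying a one-hot tensor by the embedding weight matrix yields the same vectors as indexing, while keeping a gradient path back to the input. The sizes are hypothetical, and `embed_tokens` only mimics the attribute name used in fairseq's transformer encoder.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, embed_dim, padding_idx = 10, 4, 1  # toy sizes
embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)

tokens = torch.tensor([[4, 7, 2], [5, 2, 1]])
# A leaf tensor standing in for the (one-hot) output of the first model.
one_hot = F.one_hot(tokens, vocab_size).float().requires_grad_()

# Index lookup vs. differentiable matrix product with the embedding table.
by_index = embed_tokens(tokens)            # (bsz, seqlen, embed_dim)
by_matmul = one_hot @ embed_tokens.weight  # (bsz, seqlen, embed_dim)
assert torch.allclose(by_index, by_matmul)

# Unlike an integer lookup, the matmul path sends gradients to the input.
by_matmul.sum().backward()
print(one_hot.grad.shape)  # torch.Size([2, 3, 10])
```

The same product also works for soft (not exactly one-hot) distributions, and any later scaling, such as the encoder's `embed_scale`, applies unchanged.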

gmryu commented 2 years ago

Sorry for misunderstanding your intention.

In that case, you have to write a new model, and it also comes down to how you train one model together with the other. Have you written a new train.py?

--

I saw someone suggest using a linear ReLU instead of argmax so that you still get some gradients?
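It is not clear which "linear ReLU" trick was meant. A different, commonly used alternative is the Gumbel-softmax straight-through estimator, available as PyTorch's `torch.nn.functional.gumbel_softmax`: with `hard=True` the forward pass produces one-hot vectors while gradients flow through the soft sample. This is a sketch; the random logits and the toy downstream loss are stand-ins, not anything from the thread.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Random logits standing in for the first model's output: (bsz, seqlen, vocab).
logits = torch.randn(2, 3, 10, requires_grad=True)

# hard=True: the forward value is a one-hot vector at each position, while
# the backward pass uses the gradient of the soft Gumbel-softmax sample.
one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)

# A toy downstream loss (one random weight per vocab entry) to show that
# gradients reach the logits even though the forward values are one-hot.
downstream_weights = torch.randn(10)
loss = (one_hot * downstream_weights).sum()
loss.backward()
print(one_hot.argmax(dim=-1))  # the sampled index at each position
print(logits.grad.shape)       # torch.Size([2, 3, 10])
```

The temperature `tau` trades off how closely the soft sample (and hence the gradient) tracks the hard one-hot choice.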

KevinZhoutianyi commented 2 years ago

Yes! I wrote a new trainer and used `generator=task.build_generator([fqmodel], cfg.generation)` and the checkpoint model to do customized training. But, as you said, I guess I need to set up a new model. Do you know of any tutorial about modifying an existing model? I really don't want to set up the model from scratch just for one tiny change (the positional encoding) in the model.

gmryu commented 2 years ago

Copy and paste the transformer, or whichever model you use. That is most, if not all, of the tutorials you will find publicly and officially. Actually, modifying a model is easy as long as you are not afraid of autograd.

If it is only the positional embeddings, you can copy the encoder and decoder and adjust their `__init__`, `forward_embedding`, and `forward` to support a new PosEmbedding class, or deal with the one-hot input inside `forward` before summing up the embeddings. Well, there is nothing more I can say at this stage; it really depends on how you want to handle the one-hot vectors. Feel free to ask more detailed questions and I will answer as far as I can help.
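Following the second option (handling the one-hot input inside `forward_embedding`), here is a standalone PyTorch sketch, not fairseq's implementation. It rests on an assumption worth checking against your fairseq version: the positional embeddings (learned or sinusoidal) use the token tensor only to locate padding, via fairseq's `utils.make_positions`, which the helper below re-implements as we understand it. If that holds, positions can be computed from an argmax of the *detached* one-hot input without cutting any gradient path, since positions never depended differentiably on token identities. `OneHotEmbedding`, `max_positions`, and the toy sizes are hypothetical.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def make_positions(tokens: torch.Tensor, padding_idx: int) -> torch.Tensor:
    # Non-pad tokens get padding_idx+1, padding_idx+2, ...; pads keep padding_idx.
    mask = tokens.ne(padding_idx).long()
    return torch.cumsum(mask, dim=1) * mask + padding_idx

class OneHotEmbedding(nn.Module):
    """Hypothetical token + learned positional embedding for one-hot input."""

    def __init__(self, vocab_size, embed_dim, padding_idx, max_positions=1024):
        super().__init__()
        self.padding_idx = padding_idx
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        # Table sized so the largest position index (max_positions + padding_idx) fits.
        self.embed_positions = nn.Embedding(
            max_positions + padding_idx + 1, embed_dim, padding_idx=padding_idx)

    def forward(self, one_hot: torch.Tensor) -> torch.Tensor:
        # Token embeddings as a matmul: differentiable w.r.t. one_hot.
        x = self.embed_scale * (one_hot @ self.embed_tokens.weight)
        # Positions only need the padding layout, so argmax on detached input is fine.
        tokens = one_hot.detach().argmax(dim=-1)
        positions = make_positions(tokens, self.padding_idx)
        return x + self.embed_positions(positions)

torch.manual_seed(0)
layer = OneHotEmbedding(vocab_size=10, embed_dim=8, padding_idx=1)
tokens = torch.tensor([[4, 7, 2], [5, 2, 1]])  # index 1 is padding here
one_hot = F.one_hot(tokens, 10).float().requires_grad_()
out = layer(one_hot)
out.sum().backward()
print(out.shape)  # torch.Size([2, 3, 8])
```

If the first model emits soft distributions rather than exact one-hot vectors, the argmax could misplace padding, so deriving the padding layout from the first model's output lengths would be safer. The same recovered `tokens` could also stand in wherever the encoder builds its padding mask from token ids.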

KevinZhoutianyi commented 2 years ago

Thank you so much, gmryu, you are so nice! I'll try what you said!