prajdabre / yanmtt

Yet Another Neural Machine Translation Toolkit

Add xProphetNet model into this toolkit #19

Open koukoulala opened 2 years ago

koukoulala commented 2 years ago

Hi, this is a very helpful toolkit; I have learned a lot from it.

Recently, I have been focusing on multilingual title generation tasks and found that the xProphetNet model performs well, especially on the XGLUE benchmark. I wanted to distill a small xProphetNet model and pre-train it on my own dataset. However, I could not find the relevant pre-training code, so I would like to ask whether you would consider adding pre-training for the xProphetNet model. I can provide the code for the model architecture and for the fine-tuning process (which reproduces the reported results).

To address possible doubts: I considered using the mBART model, but the tokenizer of the pre-trained mBART model is language-specific, and my own dataset cannot be split by language for training. I considered putting all the data in one file to train a unified tokenizer, but I was concerned that the resulting reduction in per-language vocab_size might hurt the model's performance. Do you have any suggestions?

Thanks

prajdabre commented 2 years ago

Hi,

Since xProphetNet is an encoder-decoder model, this should be fairly straightforward. First, use the dev branch, which will soon be merged into the main branch.

If you wish to pre-train your own ProphetNet, you can use my code as it is, but your main challenge will be modifying the loss function.

  1. In common_utils.py go to this function: def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None) and extend it to def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, prophet_net=False, prophet_net_ngrams=4)
  2. Inside the modified method, wrap the existing code in a for loop running over indices i = 0 to prophet_net_ngrams - 1. The value of i indicates how far the labels should be left-shifted. When left-shifting, make sure the right side is padded with the index of the pad token. After shifting, the rest of the loss computation stays the same, except that the per-stream loss values are added/averaged (a sketch follows this list).
  3. Ensure that for pretrain_nmt.py, the appropriate flags are created and passed to the lines wherever label_smoothed_nll_loss is called. Also ensure that the prophetnet tokenizers are created appropriately at the very beginning of the main function.
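
Here is a rough sketch of what that extension could look like, assuming a fairseq-style label-smoothed NLL and that, when prophet_net=True, lprobs carries one prediction stream per future n-gram (shape [prophet_net_ngrams, batch, seq_len, vocab]). The actual code in common_utils.py may differ, so treat this as an illustration rather than a drop-in patch.

```python
import torch

def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None,
                            prophet_net=False, prophet_net_ngrams=4):
    """Label-smoothed NLL loss, optionally averaged over ProphetNet n-gram streams."""
    def single_stream_loss(stream_lprobs, stream_target):
        stream_target = stream_target.unsqueeze(-1)
        nll_loss = -stream_lprobs.gather(dim=-1, index=stream_target)
        smooth_loss = -stream_lprobs.sum(dim=-1, keepdim=True)
        if ignore_index is not None:
            pad_mask = stream_target.eq(ignore_index)
            nll_loss.masked_fill_(pad_mask, 0.0)
            smooth_loss.masked_fill_(pad_mask, 0.0)
        nll_loss, smooth_loss = nll_loss.sum(), smooth_loss.sum()
        eps_i = epsilon / stream_lprobs.size(-1)
        return (1.0 - epsilon) * nll_loss + eps_i * smooth_loss, nll_loss

    if not prophet_net:
        return single_stream_loss(lprobs, target)

    pad = ignore_index if ignore_index is not None else 0
    total_loss, total_nll = 0.0, 0.0
    for i in range(prophet_net_ngrams):
        # Left-shift the labels by i positions and pad on the right with the pad
        # index, so stream i is scored against the token i steps further ahead.
        shifted = torch.cat(
            [target[:, i:], target.new_full((target.size(0), i), pad)], dim=1)
        loss_i, nll_i = single_stream_loss(lprobs[i], shifted)
        total_loss += loss_i
        total_nll += nll_i
    return total_loss / prophet_net_ngrams, total_nll / prophet_net_ngrams
```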

If you want to do pre-training EXACTLY like the ProphetNet paper, then you will have to mess with the generate_batches_monolingual_masked method in common_utils.py, where you have to do the tokenization (and masking) exactly as it is done in the paper. If you want to pre-fine-tune an existing xProphetNet on your own monolingual data, then modifying generate_batches_monolingual_masked is just as important.
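
For illustration only, here is a rough sketch of masked-span batch creation (the helper is made up and this is NOT the paper's exact masking recipe; for exact replication, follow the span lengths and masking ratio from the ProphetNet paper):

```python
import random

# Rough illustration only: mask one contiguous span of subword tokens so the
# decoder has to reconstruct it. Replace this with the paper's exact masking
# ratio and span rules if you want to replicate ProphetNet pre-training.
def mask_span(tokens, mask_token="[MASK]", mask_ratio=0.15):
    span_len = max(1, int(len(tokens) * mask_ratio))
    start = random.randint(0, max(0, len(tokens) - span_len))
    masked_source = tokens[:start] + [mask_token] * span_len + tokens[start + span_len:]
    decoder_target = tokens[start:start + span_len]  # what the decoder must predict
    return masked_source, decoder_target
```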

If you wish to directly fine-tune xprophetnet then you will need to do the following:

  1. Modify the generate_batches_bilingual and generate_batches_for_decoding methods to handle tokenization the way it's meant to be done. You will see that I already have if-elses dealing with mbart-50 and bart (see the sketch after this list).
  2. Ensure that for train_nmt.py, and decode_nmt.py the appropriate flags are created. Also ensure that the prophetnet tokenizers are created appropriately at the very beginning of the main function.
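
As a rough sketch of what that branch could look like (the model_family switch and helper names here are hypothetical, not existing yanmtt flags; the real change goes inside the generate_batches_* methods):

```python
from transformers import AutoTokenizer

# Sketch only: extend the existing mbart-50 / bart if-elses with an xprophetnet
# branch that skips language-code handling altogether.
def load_tokenizer(model_family):
    if model_family == "xprophetnet":
        # xProphetNet reuses the XLM-R SentencePiece vocabulary; sequences simply
        # end with [SEP] and no language-code tokens are added.
        return AutoTokenizer.from_pretrained("microsoft/xprophetnet-large-wiki100-cased")
    # ... keep the existing mbart-50 / bart branches here ...
    raise NotImplementedError(model_family)

def encode_pair(tokenizer, src_sents, tgt_sents, max_len=256):
    # Plain tokenization for xprophetnet: no src_lang/tgt_lang tags on either side.
    batch = tokenizer(src_sents, padding=True, truncation=True,
                      max_length=max_len, return_tensors="pt")
    batch["labels"] = tokenizer(tgt_sents, padding=True, truncation=True,
                                max_length=max_len, return_tensors="pt").input_ids
    return batch
```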

As for your doubts:

  1. mBART, xprophetnet and xlm-roberta have the same tokenizers. I am not exactly sure what you mean by "my own dataset cannot be language-specific for training". Please give an example of what your data looks like.
  2. If you are worried about a reduction in per-language vocabulary, then just use a larger total vocabulary; after all, mBART has a 250k vocab. If you want a smaller vocab, then you may want to do some script unification, as we did for Indic languages (https://arxiv.org/abs/2109.02903). You need to make a compromise somewhere. A minimal tokenizer-training sketch follows this list.
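
For the unified-tokenizer route, a minimal SentencePiece sketch (file names and the vocabulary size are placeholders; yanmtt's own tokenizer-creation scripts may differ):

```python
import sentencepiece as spm

# Train one shared subword model on the concatenation of all languages, with a
# large total vocabulary so each language still gets a reasonable share of it.
spm.SentencePieceTrainer.train(
    input="all_languages_combined.txt",  # one sentence per line, all languages mixed
    model_prefix="shared_spm",
    vocab_size=64000,                    # increase this if per-language coverage looks poor
    character_coverage=1.0,              # keep characters from rare scripts
    model_type="unigram",
)
```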

Overall, I hope these points are useful to you. Feel free to make changes, test them and send a PR ;)

koukoulala commented 2 years ago

Very useful answers! I plan to generate a large total vocab and pre-train on my corpus with mBART first, and then introduce the xProphetNet model if the performance is not ideal.

Thanks.

prajdabre commented 2 years ago

Hi,

Some tips for you in case you need them:

  1. Ensure vocabulary balance in multilingual settings. If one language has more data than another, then you should probably balance the data sizes across languages (see the sketch after these tips).
  2. mBART pre-training can be tricky if you have lots of data. You need to mess with learning rates, batch sizes etc to avoid suboptimal training. (I still have problems figuring it out)
  3. I think I will code up xprophetnet in the next few days and make a push so check back in a few days and feel free to use it.
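
For point 1, a common way to balance is temperature-based sampling over the per-language data sizes. A small sketch (not part of yanmtt): T=1 keeps the raw proportions, while larger T flattens them towards uniform so low-resource languages are seen more often.

```python
def sampling_probs(line_counts, T=5.0):
    # line_counts: {language: number of sentences}; returns sampling probabilities.
    total = sum(line_counts.values())
    weights = {lang: (n / total) ** (1.0 / T) for lang, n in line_counts.items()}
    z = sum(weights.values())
    return {lang: w / z for lang, w in weights.items()}

# e.g. sampling_probs({"en": 1_000_000, "hi": 50_000, "mr": 10_000})
```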

Good luck!