AI4Bharat / IndicTrans2

Translation models for 22 scheduled languages of India
https://ai4bharat.iitm.ac.in/indic-trans2
MIT License

Distillation of en-indic base model #74

Closed · harshyadav17 closed this 3 months ago

harshyadav17 commented 3 months ago

Hey @prajdabre @PranjalChitale

I was going through the paper's section on the distillation of the models, but couldn't find the relevant source code in this repo. Are we directly using the following repo, https://github.com/VarunGumma/fairseq/tree/main?tab=readme-ov-file, with the following arguments:

--teacher-checkpoint-path $teacher_ckpt --task translation_with_kd --criterion label_smoothed_cross_entropy_with_kd --kd-args '{"strategy": "word_level"}'

In the paper it is mentioned that KL divergence is used for the teacher-student training. Also, could you please comment more on the share_decoder_input_output_embed option?
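
(For context, word-level KD with a KL-divergence term usually looks something like the sketch below. This is an illustrative PyTorch version only, not necessarily the exact criterion implemented in the fairseq fork; the tensor names and the alpha interpolation weight are assumptions.)

```python
# Illustrative sketch of word-level knowledge distillation with a KL term.
# Assumes `student_logits` / `teacher_logits` of shape (batch, seq_len, vocab)
# and gold `target` ids of shape (batch, seq_len); all names are hypothetical.
import torch
import torch.nn.functional as F

def word_level_kd_loss(student_logits, teacher_logits, target, pad_idx,
                       alpha=0.5, label_smoothing=0.1):
    # Per-token KL(teacher || student), summed over the vocabulary dimension.
    s_lprobs = F.log_softmax(student_logits, dim=-1)
    t_probs = F.softmax(teacher_logits, dim=-1)
    kd = F.kl_div(s_lprobs, t_probs, reduction="none").sum(dim=-1)

    # Ignore padding positions when reducing the KD term.
    kd = kd.masked_fill(target.eq(pad_idx), 0.0).sum()

    # Label-smoothed cross-entropy against the gold references.
    ce = F.cross_entropy(
        student_logits.transpose(1, 2), target,
        ignore_index=pad_idx, label_smoothing=label_smoothing, reduction="sum",
    )

    # Interpolate distillation and gold supervision (alpha is an assumed knob).
    return alpha * kd + (1.0 - alpha) * ce
```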

It would be really helpful if you could share the script/syntax for training and for getting the correct model architecture with weight initialisation.

Thanks!

PranjalChitale commented 3 months ago

The code for the distillation procedure is available here.

Yes, it uses a custom fork of fairseq, VarunGumma/fairseq, as outlined in the distillation branch.

The share_decoder_input_output_embed flag is used for weight tying. As the name suggests, it ties the weights of the decoder's input embedding with the weights for the output projection. This is a common technique that helps reduce the number of parameters in the model.
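
For what it's worth, in plain PyTorch the tying amounts to something like the following sketch (fairseq handles this internally when the flag is set; the variable names and dimensions here are just for illustration):

```python
import torch.nn as nn

vocab_size, d_model = 32000, 512
embed = nn.Embedding(vocab_size, d_model)               # decoder input embedding
out_proj = nn.Linear(d_model, vocab_size, bias=False)   # projection to vocab logits

# Tie the two: both modules now share one (vocab_size, d_model) parameter,
# removing vocab_size * d_model weights from the model.
out_proj.weight = embed.weight
assert out_proj.weight.data_ptr() == embed.weight.data_ptr()
```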