harshyadav17 closed this issue 3 months ago
The code for the distillation procedure is available here.
Yes, it uses a custom fork of fairseq VarunGumma/fairseq as outlined in the Distillation branch.
The `share_decoder_input_output_embed` flag enables weight tying: as the name suggests, it ties the weights of the decoder's input embedding to the weights of its output projection. This is a common technique that reduces the number of parameters in the model.
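To illustrate the idea (this is a minimal numpy sketch of weight tying in general, not the fairseq implementation), the embedding table and the output projection are literally the same matrix, so the parameters are stored once:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, dim = 8, 4

# One shared matrix serves as both the input embedding table
# and the output projection (logits = hidden @ E.T).
E = rng.normal(size=(vocab, dim))

def embed(token_ids):
    return E[token_ids]      # input embedding lookup

def project(hidden):
    return hidden @ E.T      # output projection reuses the same E

h = embed(np.array([1, 3]))  # shape (2, dim)
logits = project(h)          # shape (2, vocab)

# Because E is used in both places, the model stores vocab * dim
# parameters once instead of twice.
assert logits.shape == (2, vocab)
```

In a real transformer decoder, gradients from both the embedding lookup and the output projection flow into the single shared parameter.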
Hey @prajdabre @PranjalChitale
I was going through the paper's description of the model distillation, but couldn't find the relevant source code in this repo. Are we directly using the following repo: https://github.com/VarunGumma/fairseq/tree/main?tab=readme-ov-file with the following arguments:
--teacher-checkpoint-path $teacher_ckpt --task translation_with_kd --criterion label_smoothed_cross_entropy_with_kd --kd-args '{"strategy": "word_level"}'
The paper mentions that KL divergence is used for the teacher-student training. Also, could you please comment further on share_decoder_input_output_embed?
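For concreteness, here is a minimal numpy sketch of what word-level KD with KL divergence computes; the function name and the optional temperature are illustrative assumptions, not the fork's exact code:

```python
import numpy as np

def log_softmax(x, axis=-1):
    # Numerically stable log-softmax.
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def word_level_kd_loss(student_logits, teacher_logits, T=1.0):
    """KL(teacher || student) per target position, averaged over positions."""
    t_logp = log_softmax(teacher_logits / T)
    s_logp = log_softmax(student_logits / T)
    t_p = np.exp(t_logp)
    # KL divergence for each token position, summed over the vocabulary.
    kl = (t_p * (t_logp - s_logp)).sum(axis=-1)
    return kl.mean()

rng = np.random.default_rng(0)
student = rng.normal(size=(5, 16))  # (target positions, vocab size)
teacher = rng.normal(size=(5, 16))

loss = word_level_kd_loss(student, teacher)
assert loss >= 0.0                  # KL divergence is non-negative
# When the student matches the teacher exactly, the KD loss is zero.
assert np.isclose(word_level_kd_loss(teacher, teacher), 0.0)
```

The student is pushed to match the teacher's full output distribution at every target token, rather than only the one-hot reference label.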
It would be really helpful if you could share the script/syntax for training and for obtaining the correct model architecture with weight initialisation.
Thanks!