bigscience-workshop / Megatron-DeepSpeed

Ongoing research training transformer language models at scale, including: BERT & GPT-2

Distill megatron - test Draft WIP #352

Closed. younesbelkada closed this 1 year ago

younesbelkada commented 1 year ago

An attempt to perform knowledge distillation using Megatron-DeepSpeed

Disclaimer: this is a very rough version of the code. The PR is here to compare the original code with this modified version; for now I don't plan to merge it.

Updates on 28.09.2022

This version is still rough: I had to add a `student_` argument to all Megatron modules, since the arguments are retrieved directly from the global variable. The alternative would be to rewrite each class with a `Student` suffix, e.g. `GPTModelStudentPipe`. I preferred the first option to get a quick working implementation.
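For illustration, here is a minimal, hypothetical sketch of the pattern described above, not the PR's actual code: the `Args`/`get_args` stand-ins and the `student_*` field names are assumptions, and only the idea of reading `student_`-prefixed values from a global args object mirrors what the paragraph describes.

```python
from dataclasses import dataclass

@dataclass
class Args:
    hidden_size: int = 1024
    num_layers: int = 24
    # hypothetical student-specific fields added via the `student_` prefix
    student_hidden_size: int = 512
    student_num_layers: int = 6

_GLOBAL_ARGS = Args()

def get_args():
    """Stand-in for Megatron's global-args accessor."""
    return _GLOBAL_ARGS

class TransformerBlock:
    def __init__(self, student=False):
        args = get_args()
        # When building the student model, read the `student_`-prefixed
        # values instead of rewriting every class with a Student suffix.
        prefix = "student_" if student else ""
        self.hidden_size = getattr(args, prefix + "hidden_size")
        self.num_layers = getattr(args, prefix + "num_layers")

teacher_block = TransformerBlock(student=False)  # 1024 hidden, 24 layers
student_block = TransformerBlock(student=True)   # 512 hidden, 6 layers
```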

The forward and backward passes seem to work for the student model; for now I am not computing the teacher's logits. Two possible solutions:

1. In `distill_train_step`, add a step where we retrieve the teacher's logits. In this case, would we need to change the `deepspeed.PipelineEngine` internals?
2. Store the embedding layer of the teacher model inside the student model and gather the teacher's last hidden states. Once gathered, apply the forward pass of the teacher's embedding layer to get the logits (see the sketch after this list). (cc @thomasw21 as discussed offline)
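As a hedged sketch (not the PR's code), the snippet below illustrates option 2, recovering teacher logits from the last hidden states via the teacher's embedding weight as in a tied LM head, followed by the distillation loss that either route would ultimately feed. Names such as `temperature` and `alpha` are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def teacher_logits_from_hidden(last_hidden, teacher_embedding):
    # Tied output projection: logits = hidden @ embedding_weight.T
    return F.linear(last_hidden, teacher_embedding.weight)

def distill_loss(student_logits, teacher_logits, labels,
                 temperature=2.0, alpha=0.5):
    # Soft-target term: KL between temperature-scaled distributions.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    kd = F.kl_div(log_p_student, p_teacher,
                  reduction="batchmean") * temperature ** 2

    # Hard-target term: the usual language-modeling cross-entropy.
    ce = F.cross_entropy(student_logits.view(-1, student_logits.size(-1)),
                         labels.view(-1))
    return alpha * kd + (1.0 - alpha) * ce

# Example shapes: (batch, seq, vocab) logits, (batch, seq) token labels.
vocab, hidden = 100, 32
emb = torch.nn.Embedding(vocab, hidden)
teacher_hidden = torch.randn(2, 8, hidden)
t_logits = teacher_logits_from_hidden(teacher_hidden, emb)
s_logits = torch.randn(2, 8, vocab, requires_grad=True)
labels = torch.randint(0, vocab, (2, 8))
distill_loss(s_logits, t_logits, labels).backward()
```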

main TODOs