An attempt to perform knowledge distillation using Megatron-DeepSpeed
Disclaimer: this is a very rough version of the code. The PR exists to compare the original code against this modified version; for now I don't plan to merge it.
Updates on 28.09.2022
This version is very ugly: I had to add a `student_` prefixed argument on all Megatron modules, since the arguments are retrieved directly from the global variable. The alternative would have been to rewrite each class with a `Student` suffix, e.g. `GPTModelStudentPipe`. I preferred the first solution so as to get a quick working implementation.
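A minimal sketch of the workaround described above (all names here are illustrative, not the actual Megatron code): because Megatron modules read their hyperparameters from a global args namespace, duplicating each relevant argument with a `student_` prefix lets the same class build either the teacher-sized or the student-sized model.

```python
# Hypothetical sketch: a stand-in for megatron's global args pattern.
from dataclasses import dataclass

@dataclass
class Args:
    hidden_size: int = 1024
    num_layers: int = 24
    # student_-prefixed duplicates of the same hyperparameters
    student_hidden_size: int = 512
    student_num_layers: int = 12

_GLOBAL_ARGS = Args()

def get_args():
    # stand-in for megatron.get_args(), which returns a global namespace
    return _GLOBAL_ARGS

class TransformerBlock:
    """Toy module that, like Megatron modules, sizes itself from get_args()."""
    def __init__(self, student: bool = False):
        args = get_args()
        prefix = "student_" if student else ""
        self.hidden_size = getattr(args, prefix + "hidden_size")
        self.num_layers = getattr(args, prefix + "num_layers")

teacher_block = TransformerBlock()
student_block = TransformerBlock(student=True)
```

The downside, as noted, is that every module constructor has to thread the `student` flag through, which is why it is ugly.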
The forward and backward passes seem to run for the student model; for now I am not computing the teacher's logits. There are two possible solutions for that:
1 - In `distill_train_step`, add a step that retrieves the teacher's logits. In this case, would we need to change the `deepspeed.PipelineEngine` internals?
2 - Store the teacher model's embedding layer inside the student model and gather the teacher model's last hidden states. Once gathered, apply the teacher's embedding layer to those hidden states to get the logits. (cc @thomasw21, as discussed offline)
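Solution 2 could look roughly like the following (a sketch only, assuming GPT-style tied input/output embeddings, where the output logits are the hidden states projected against the embedding matrix; shapes and names are illustrative):

```python
# Hypothetical sketch of solution 2: the student keeps a frozen copy of the
# teacher's (tied) embedding matrix and projects the teacher's last hidden
# states onto the vocabulary to recover the teacher logits locally.
import torch

vocab_size, hidden = 50257, 1024

teacher_embedding = torch.nn.Embedding(vocab_size, hidden)
teacher_embedding.weight.requires_grad_(False)  # teacher side is frozen

def teacher_logits_from_hidden(last_hidden_states: torch.Tensor) -> torch.Tensor:
    # tied output head: logits = h @ E^T, as in GPT-style models
    return last_hidden_states @ teacher_embedding.weight.t()

last_hidden = torch.randn(2, 8, hidden)              # [batch, seq, hidden]
logits = teacher_logits_from_hidden(last_hidden)     # [batch, seq, vocab]
```

The appeal of this route is that only the (smaller) last hidden states need to be communicated from the teacher, and the projection to logits happens on the student side.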
main TODOs

- [x] load teacher model from a checkpoint - make sure to use a copy of the checkpoint
- [ ] load student model from a checkpoint - make sure to use a copy of the checkpoint
- [x] make the broadcasting of the teacher logits (or last hidden states) work
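For the broadcasting item, the shape such a step could take with plain `torch.distributed` is sketched below (illustrative only: the function name, source rank, and the single-process group used to exercise the call are all assumptions, not the actual pipeline-parallel wiring):

```python
# Hypothetical sketch: broadcast the teacher's output tensor from the rank
# that produced it to the other ranks in the group.
import tempfile
import torch
import torch.distributed as dist

def broadcast_teacher_output(tensor, src_rank, group=None):
    # dist.broadcast works in place: non-source ranks must pre-allocate a
    # buffer with the same shape and dtype before the call
    dist.broadcast(tensor, src=src_rank, group=group)
    return tensor

# single-process group (world_size=1) just to exercise the call locally
store_file = tempfile.mktemp()
dist.init_process_group("gloo", init_method=f"file://{store_file}",
                        rank=0, world_size=1)
teacher_logits = torch.randn(2, 8, 50257)
received = broadcast_teacher_output(teacher_logits.clone(), src_rank=0)
dist.destroy_process_group()
```

In the real pipeline-parallel setup the source rank would be the stage holding the teacher's last layer, and the buffer shapes would have to be agreed on by all participating ranks beforehand.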