songmzhang / DSKD

Repo for Paper "Dual-Space Knowledge Distillation for Large Language Models".

Usage with other model combinations #3

Closed · botox-100 closed this issue 1 month ago

botox-100 commented 1 month ago

Hi, I read your interesting paper about dual-space KD and have already started trying out your code. I was able to get things running by downloading all components, but there are some points where I need some guidance:

Q1: I am not sure about the general workflow of a complete knowledge distillation. I assume it is the following:

  • Run fine-tuning (SFT) individually for the teacher and the student model: Llama2 (teacher), TinyLlama (student).
  • Use the KD script for TinyLlama to train the fine-tuned TinyLlama model with the fine-tuned Llama2 teacher model (same vocabulary).

Is this correct, or am I mixing things up?

Q2: I would like to do the same as above with Mistral 7B as the teacher and a self-generated, stripped-down variant of Mistral 7B as the student, so in general they share the same vocabulary and structure. My goal is to enhance the capabilities of the student this way while still having a smaller model. Can I apply the same workflow? Is it just a matter of copying the shell script and editing it so that the mentioned models are used for the distillation task, or do I have to adapt other things as well?

Q3: If I would like to do the same with different models, e.g. Llama 3 8B/70B as the teacher and Mistral 7B as the student, I would assume the following workflow:

  • Make sure all models are in the model_hub.
  • Fine-tune the teacher (Llama 3) and the student (Mistral 7B) with adapted SFT scripts.
  • Generate the logits and run the distillation with the adapted universal logit scripts.

Thanks for your help, which would be highly appreciated. Thomas

songmzhang commented 1 month ago

Thanks for your attention to our work! Here are the responses to your questions:

Reply to Q1: The standard workflow for knowledge distillation in this repo is as follows:
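As a minimal, hedged sketch of that two-stage workflow (fine-tune the teacher, then distill into the student, as assumed in Q1), where the script names and paths are placeholders rather than the repo's exact files:

```bash
# Minimal sketch of the two-stage workflow assumed in Q1.
# Script names and paths are placeholders, not the repo's exact files.

# 1) SFT the teacher (e.g. Llama2) on the task data
bash scripts/llama2/sft_llama2.sh /path/to/model_hub /path/to/outputs

# 2) Run the KD script for the student (e.g. TinyLlama), pointing it at the
#    fine-tuned teacher checkpoint produced in step 1
bash scripts/tinyllama/kd_tinyllama.sh /path/to/model_hub /path/to/outputs
```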

Reply to Q2: Yes! If you would like to use other models, you can still follow the workflow described above. The only change needed is to set CKPT_PATH/CKPT_NAME/CKPT_TYPE to your actual paths.
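For illustration, such an edit in a copied script might look like the lines below; only the variable names CKPT_PATH/CKPT_NAME/CKPT_TYPE are taken from this reply, and all values are placeholders for your own checkpoints:

```bash
# Hypothetical edit for the Mistral pair in Q2; values are placeholders.
CKPT_TYPE="mistral"                 # model/tokenizer family of the student
CKPT_NAME="mistral-small-custom"    # your stripped-down Mistral variant
CKPT_PATH="model_hub/${CKPT_NAME}"  # actual location of the checkpoint
```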

Reply to Q3: If you would like to distill Llama3 into Mistral, you can follow this process:
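As a hedged sketch of the cross-tokenizer setting (Llama3 teacher, Mistral student), reusing the CKPT_* pattern from the previous reply; the TEACHER_* variable names and all values are assumptions for illustration only:

```bash
# Hypothetical setup for Q3: teacher and student come from different model
# families with different vocabularies, so a cross-tokenizer method (e.g. the
# universal-logit-style scripts mentioned in Q3) is used instead of a
# same-vocabulary KD loss. Variable names and values are assumptions.
CKPT_TYPE="mistral"
CKPT_NAME="mistral-7b-sft"
CKPT_PATH="model_hub/${CKPT_NAME}"

TEACHER_CKPT_TYPE="llama3"
TEACHER_CKPT_NAME="llama3-8b-sft"
TEACHER_CKPT_PATH="model_hub/${TEACHER_CKPT_NAME}"
```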

Some points that you might have misunderstood:

Hope our responses help. Please feel free to contact me anytime if you have any other questions about our code.

botox-100 commented 1 month ago

Hi songmzhang,

First, thanks a lot for your prompt reaction and the explanations; that makes it much clearer for me! Meanwhile, I have started to work with the workflows you described, but I still need some time because GPU access is limited (as always) ;-).

I came across one issue where I think there is a little bug in the saving method used for the epochs. If I run SFT or DSKD for 10 epochs, the 9th epoch is always the last one saved to disk (even if the log claims otherwise). And it is not only the numbers being wrong; the timestamps of the files are also older than expected for the 10th epoch. I had a short look into distillation.py, where the corresponding method is located, but was not able to find the glitch at first glance.

Thanks a lot for all your effort, and I will let you know about my results once I am done,

Thomas

songmzhang commented 1 month ago

Hi, I think this is because our distillation.py saves the best checkpoint during training, not the last one. So the 9th checkpoint may simply be the best one in your run (highest Rouge-L or lowest validation loss). For this feature, we use a training argument --keep-best-n-checkpoints to save the top-n checkpoints, and it was set to 1 in both scripts you ran. You can simply modify this line to adjust the maximum number of saved checkpoints.
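For example, if the scripts pass training options through a shell variable, the change might look like the snippet below; only the --keep-best-n-checkpoints argument itself is named in this reply, and the surrounding OPTS pattern is an assumption about the script's structure:

```bash
# Keep the checkpoints of all 10 epochs instead of only the single best one.
# (The flag comes from the reply above; the OPTS+= pattern is assumed.)
OPTS+=" --keep-best-n-checkpoints 10"
```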

botox-100 commented 1 month ago

Thanks songmzhang,

that clearly explains the behavior!

Have a nice Sunday!