jxiw / MambaInLlama

[NeurIPS 2024] Official Repository of The Mamba in the Llama: Distilling and Accelerating Hybrid Models
https://arxiv.org/abs/2408.15237
Apache License 2.0

Why doesn’t kl_div ignore -100 in pseudo labels? #11

Open yynil opened 2 months ago

yynil commented 2 months ago

The original code looks like this: `kl_loss = F.kl_div(F.log_softmax(student_logits, dim=-1), targets, reduction='batchmean')`

Although the shape of the loss curve stays the same, the optimizer may still try to fit targets at positions we don't care about. I think it would be better to use the positions labeled -100 as a mask on the inputs.
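A minimal PyTorch sketch of the suggested masking, assuming `targets` are teacher probabilities with the same `(batch, seq, vocab)` shape as `student_logits`, and that `labels` are already aligned with the logit positions, with -100 marking positions to skip. The function name `masked_kl_loss` is hypothetical, not taken from the repository.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # label value marking positions that should not contribute to the loss


def masked_kl_loss(student_logits: torch.Tensor,
                   teacher_probs: torch.Tensor,
                   labels: torch.Tensor,
                   ignore_index: int = IGNORE_INDEX) -> torch.Tensor:
    """KL(teacher || student), averaged only over positions whose label != ignore_index.

    student_logits: (batch, seq, vocab) raw student logits
    teacher_probs:  (batch, seq, vocab) teacher probabilities (the `targets`)
    labels:         (batch, seq) labels aligned with the logits; ignore_index = masked
    """
    log_p_student = F.log_softmax(student_logits, dim=-1)
    # Keep per-element KL terms, then sum over the vocabulary -> one value per position.
    per_token = F.kl_div(log_p_student, teacher_probs, reduction="none").sum(dim=-1)
    mask = (labels != ignore_index).to(per_token.dtype)
    # Average over kept positions only; clamp avoids dividing by zero if everything is masked.
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)
```

Masked positions now contribute neither loss nor gradient. Note that `reduction='batchmean'` on a 3-D input divides by the batch size only, so the loss in the snippet above is a per-sequence sum; this sketch normalizes per kept token instead, so the two values are not directly comparable.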

jxiw commented 2 months ago

Thank you for your insightful observation. It absolutely makes sense. You're right that the optimizer might end up focusing on unnecessary targets, which isn't ideal. I haven't fully optimized this part yet, and using the -100 indexes as a mask to filter out irrelevant inputs would definitely be a more efficient approach. Thank you for your suggestion!

yynil commented 1 month ago

Actually, I wonder whether this modification could eliminate the need for SFT/DPO. I made a modified version to distill an RWKV-in-Llama3.1 hybrid. Let's see what happens with this modified loss calculation.

jxiw commented 1 month ago

So, from my understanding, I want to use as many labels as possible. What we hope is that, for any given x, we minimize the difference between the teacher's and the student's predictions. That's why I use more labels. Although the user's input is not part of the model's output, we can still minimize the cross-entropy difference for next-token prediction on it. Another reason is that there is an ongoing debate (here and here) about whether instruct models are trained only on the assistant completions or with next-token prediction over the whole sequence. As far as I know, Zephyr-style models do the latter. I don't know whether Llama-3-Instruct was fine-tuned on only the assistant part or on the whole sequence, but if it is the latter, I think computing the next-token KL divergence over the whole sequence is meaningful.
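The two options in this debate differ only in which positions carry a target. A hypothetical helper, assuming a per-token boolean `assistant_mask` is available from the chat template (the thread does not show how the repository builds it), could produce labels for either choice, using -100 for positions with no target:

```python
import torch

IGNORE_INDEX = -100  # positions with this label contribute no loss


def build_next_token_labels(input_ids: torch.Tensor,
                            assistant_mask: torch.Tensor,
                            assistant_only: bool) -> torch.Tensor:
    """Build labels aligned with the logits: position t is trained to predict token t+1.

    input_ids:      (batch, seq) token ids
    assistant_mask: (batch, seq) bool, True for tokens inside assistant turns
    assistant_only: True  -> only assistant tokens are targets (completion-only)
                    False -> every next token is a target (whole sequence)
    """
    labels = torch.full_like(input_ids, IGNORE_INDEX)
    labels[:, :-1] = input_ids[:, 1:]  # shift left; the last position has no next token
    if assistant_only:
        # Keep a position only if the token it predicts belongs to an assistant turn.
        target_is_assistant = torch.zeros_like(assistant_mask)
        target_is_assistant[:, :-1] = assistant_mask[:, 1:]
        labels[~target_is_assistant] = IGNORE_INDEX
    return labels
```

With `assistant_only=False`, the only ignored position is the final one, which matches the whole-sequence objective described above.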

However, we are open to exploring new approaches. If you think the SFT/DPO part is not necessary, feel free to omit it. In my experience, the SFT part improves performance the most, since it uses GPT-4 synthetic labels (using only pseudo labels in stage 1 gives a poor MMLU score). Another reason we do SFT/DPO is that when this project started, Llama3 had not yet been released and Zephyr was the best 7B instruct model at the time, so we did SFT/DPO to mimic it. You can think of SFT/DPO as redoing the alignment (instruction tuning).

If you only want to do stage 1, I suggest changing one layer at a time instead of many layers at once. In my distillation runs, that gave better results than the current stage 1. You can probably also get good scores on some common-sense reasoning tasks, but in my experience the MMLU gap is quite challenging to close. You could try using more labels.
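One way to read "changing one layer at a time" is as a sequence of stages, each converting one more layer before its own distillation run. A minimal bookkeeping sketch with hypothetical names (`Stage`, `stepwise_schedule`); the thread does not say whether previously converted layers keep training or are frozen, so this only records the order:

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class Stage:
    new_layer: int         # layer converted (and distilled) in this stage
    converted: List[int]   # all layers converted so far, including new_layer


def stepwise_schedule(layers_to_convert: Sequence[int]) -> List[Stage]:
    """Convert one layer per stage, from the lowest layer index to the highest."""
    stages: List[Stage] = []
    converted: List[int] = []
    for layer in sorted(layers_to_convert):
        converted.append(layer)
        stages.append(Stage(new_layer=layer, converted=list(converted)))
    return stages
```

For a 32-layer Llama, `stepwise_schedule(range(32))` yields 32 stages, each followed by its own distillation run, instead of a single run that converts every layer at once.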

That being said, feel free to try anything! Hope that helps!

yynil commented 1 month ago

Thanks for your kind help; that's really important and insightful for me! I have another question. I tried a very aggressive distillation that distills all 32 layers of Llama in one hybrid training run, but I got a very low student cross-entropy loss and a very high, unstable KL loss, regardless of whether I computed the KL only on the labeled tokens or on all tokens. Could you share your experience with the KL-divergence loss? Does the problem lie in the aggressive distillation (all 32 layers together)?

jxiw commented 1 month ago

Regarding the KL loss, my best model with 1/8 attention achieves under 50. I am still trying to train it with more data to push it further, but I have limited GPU resources.

Here are some suggestions based on my experience:

1) If you are only interested in non-attention models (and don't want to keep any attention), I suggest distilling stepwise: start with the first layer, then the second, and so on up to the last layer. In my experience, stepwise distillation definitely improves results.

2) You may want to consider switching to the base Llama3 model instead of Llama3-Instruct, since Llama3-Instruct uses RLHF for alignment, which is still a bit of a mystery to us. You could then use general corpora such as RefinedWeb, with Llama3 as the teacher model.

3) You may want to consider adding an MSE loss on the hidden output of each layer, similar to what's described here. But to be honest, in my experiments with stepwise distillation, that did not help much.
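Suggestion 3 can be sketched as a KL term on the output distributions plus a per-layer MSE term on hidden states. This assumes student and teacher layers produce hidden states of the same shape (which holds when each attention layer is replaced in place and the residual width is unchanged) and uses a hypothetical weight `mse_weight`; it is one common form of hidden-state matching, not necessarily the one in the referenced work:

```python
import torch
import torch.nn.functional as F
from typing import Sequence


def distill_loss(student_logits: torch.Tensor,
                 teacher_logits: torch.Tensor,
                 student_hidden: Sequence[torch.Tensor],
                 teacher_hidden: Sequence[torch.Tensor],
                 mse_weight: float = 1.0) -> torch.Tensor:
    """Per-token KL(teacher || student) plus the mean per-layer hidden-state MSE.

    *_logits: (batch, seq, vocab); *_hidden: one (batch, seq, d_model) tensor per layer.
    """
    vocab = student_logits.size(-1)
    # Flatten to (tokens, vocab) so that 'batchmean' averages over tokens.
    s_logp = F.log_softmax(student_logits.reshape(-1, vocab), dim=-1)
    t_logp = F.log_softmax(teacher_logits.detach().reshape(-1, vocab), dim=-1)
    kl = F.kl_div(s_logp, t_logp, reduction="batchmean", log_target=True)

    # Teacher hidden states are fixed targets, so no gradient flows into them.
    mse = torch.stack([F.mse_loss(s, t.detach())
                       for s, t in zip(student_hidden, teacher_hidden)]).mean()
    return kl + mse_weight * mse
```

Setting `mse_weight=0` recovers a pure logit-KL objective, which makes it easy to check whether the hidden-state term helps in a given setup.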

I sincerely hope this could help you.

yynil commented 1 month ago

Really insightful and helpful guidance! I distilled Llama3.1-Instruct using KL divergence over all labels for the first layer, and its common-sense dialogue seems as good as the original Llama's. In some special cases, such as comparing 9.11 and 9.8, the hybrid model is even better. I will continue distilling and run SFT for the remaining steps to see what happens. Thanks again!

jxiw commented 1 month ago

We are glad to hear that! If you have any further questions, please feel free to reach out at any time.