(Issue closed by tarekziade 4 months ago)
Hey @tarekziade, as a starting point I think we could try the approach presented in this paper: Pre-trained Summarization Distillation, in which they use a three-term loss (a cross-entropy data term, a KL term on the softened logits, and a hidden-state alignment term).
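For reference, the three terms could be sketched roughly like this in plain Python. The weights `alpha`/`beta`/`gamma` and the temperature are illustrative defaults, not the paper's values, and the function names are made up for this sketch:

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, gold_index,
                      student_hidden, teacher_hidden,
                      alpha=0.3, beta=1.0, gamma=1.0, temperature=2.0):
    """Weighted sum of three terms, in the spirit of the paper:
    - data loss: cross-entropy of the student against the gold label
    - logits loss: KL(teacher || student) on temperature-softened logits
    - hidden loss: MSE between aligned student and teacher hidden states
    """
    s_probs = softmax(student_logits)
    data_loss = -math.log(s_probs[gold_index])

    s_soft = softmax(student_logits, temperature)
    t_soft = softmax(teacher_logits, temperature)
    logits_loss = sum(t * math.log(t / s) for t, s in zip(t_soft, s_soft))

    n = len(student_hidden)
    hidden_loss = sum((a - b) ** 2
                      for a, b in zip(student_hidden, teacher_hidden)) / n

    return alpha * data_loss + beta * logits_loss + gamma * hidden_loss
```

In a real training loop these would of course be tensor ops over batches (e.g. with PyTorch), but the structure of the combined objective is the same.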
I've read the paper, and they seem to imply that shrink-then-fine-tune is a good approach.
So before changing the patch here, I'm trying this out in a prototype on a "regular" T5 model with 6 encoder and 6 decoder layers, which I shrink to 3+3, but I can't get it to converge: the loss stagnates and the final model doesn't work.
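For the shrinking step, one common heuristic (DistilBART-style, from the same authors) is to copy maximally spaced teacher layers into the student, always keeping the first and last. This is an illustrative sketch, not necessarily the exact mapping used in the paper or in the script above:

```python
def pick_teacher_layers(n_teacher: int, n_student: int) -> list[int]:
    """Indices of the teacher layers to copy into a smaller student.

    Keeps layers evenly spaced across the teacher stack, always
    including the first (0) and last (n_teacher - 1) layers.
    """
    if n_student == 1:
        return [0]
    step = (n_teacher - 1) / (n_student - 1)
    return [round(i * step) for i in range(n_student)]
```

With a transformers T5 model the selected blocks would then be copied over, e.g. `student.decoder.block[j].load_state_dict(teacher.decoder.block[i].state_dict())` for each picked index `i`, and likewise for the encoder.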
https://github.com/tarekziade/distill-t5/blob/main/sft.py
I have to skip metric computation during evaluation, otherwise it blows up my M1's memory even with the smallest batch size. Still trying, but I can't find the right combination, or my script has a bug...
@tarekziade I'll have a look at your script and see if I can spot a bug or something. Out of curiosity have you tried simply shrinking the decoder part?
@JulesBelveze yeah, that's my latest run. It's still running, with the loss converging around 3.5; see the chart.
@tarekziade Okay great, that's something! 🤓 Keep me posted on the model's overall performance or if there's anything else I can help with
This is a work in progress