JulesBelveze / bert-squeeze

🛠️ Tools for Transformers compression using PyTorch Lightning ⚡
https://julesbelveze.github.io/bert-squeeze/

[WIP] Make the tool work with long-t5 + booksum #58

Closed by tarekziade 4 months ago

tarekziade commented 5 months ago

This is a work in progress

JulesBelveze commented 5 months ago

Hey @tarekziade, as a starting point I think we could try the approach presented in the paper *Pre-trained Summarization Distillation*, in which they use a loss with three terms:

- a data loss: standard cross-entropy against the ground-truth summaries,
- a logits loss: KL divergence between the student's and the teacher's output distributions,
- a hidden-state loss: MSE between aligned student and teacher hidden states.
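
A minimal sketch of what that combined loss could look like, assuming Hugging Face-style model outputs produced with `output_hidden_states=True`; the `alpha_*` weights and `layer_map` below are illustrative placeholders, not values from the paper or from this repo:

```python
# Sketch of the three-term distillation loss from
# "Pre-trained Summarization Distillation" (weights are placeholders).
import torch.nn.functional as F

def distillation_loss(student_out, teacher_out, labels, layer_map,
                      alpha_data=1.0, alpha_logits=0.8, alpha_hidden=3.0):
    vocab = student_out.logits.size(-1)
    # 1) Data loss: cross-entropy against the ground-truth summary tokens.
    data_loss = F.cross_entropy(
        student_out.logits.view(-1, vocab), labels.view(-1), ignore_index=-100
    )
    # 2) Logits loss: KL divergence between student and teacher output
    #    distributions (padding positions are not masked in this sketch).
    logits_loss = F.kl_div(
        F.log_softmax(student_out.logits, dim=-1),
        F.softmax(teacher_out.logits, dim=-1),
        reduction="batchmean",
    )
    # 3) Hidden-state loss: MSE between aligned decoder hidden states,
    #    e.g. layer_map = [(0, 0), (1, 2), (2, 5)] for a 3-layer student.
    hidden_loss = sum(
        F.mse_loss(student_out.decoder_hidden_states[s],
                   teacher_out.decoder_hidden_states[t].detach())
        for s, t in layer_map
    )
    return (alpha_data * data_loss + alpha_logits * logits_loss
            + alpha_hidden * hidden_loss)
```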

tarekziade commented 5 months ago

I've read the paper, and they seem to imply that shrink-then-fine-tune (SFT) is a good approach.

So before changing the patch here, I am trying this in a prototype on a "regular" T5 model that has 6 encoder layers and 6 decoder layers, which I shrink to 3+3, but I can't get it to converge. The loss stagnates and the final model does not work.

https://github.com/tarekziade/distill-t5/blob/main/sft.py
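
For context, a sketch of the "shrink" step using `transformers`' T5 classes; keeping every other layer (0, 2, 4) is an assumption here, not necessarily what the linked script does:

```python
# Build a smaller student by copying a subset of the teacher's blocks.
from copy import deepcopy
from transformers import T5Config, T5ForConditionalGeneration

def shrink_t5(teacher, encoder_layers=(0, 2, 4), decoder_layers=(0, 2, 4)):
    config = T5Config.from_dict(teacher.config.to_dict())
    config.num_layers = len(encoder_layers)
    config.num_decoder_layers = len(decoder_layers)
    student = T5ForConditionalGeneration(config)
    student.shared.load_state_dict(teacher.shared.state_dict())
    student.lm_head.load_state_dict(teacher.lm_head.state_dict())
    # Copy the selected teacher blocks; keep index 0 in each selection so the
    # relative attention bias (only present in block 0) is carried over.
    for i, j in enumerate(encoder_layers):
        student.encoder.block[i] = deepcopy(teacher.encoder.block[j])
    for i, j in enumerate(decoder_layers):
        student.decoder.block[i] = deepcopy(teacher.decoder.block[j])
    student.encoder.final_layer_norm.load_state_dict(
        teacher.encoder.final_layer_norm.state_dict())
    student.decoder.final_layer_norm.load_state_dict(
        teacher.decoder.final_layer_norm.state_dict())
    return student

teacher = T5ForConditionalGeneration.from_pretrained("t5-small")  # 6+6 layers
student = shrink_t5(teacher)  # 3+3 student, ready for fine-tuning
```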

I have to skip metrics computation, otherwise it blows up my M1's memory even with the smallest evaluation batch size. Still trying, but I can't find the right combination, or maybe my script has a bug...
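
If the prototype evaluates with the Hugging Face Trainer, a couple of settings tend to keep evaluation memory bounded; the field names are real `TrainingArguments` options, the values are guesses:

```python
from transformers import Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(
    output_dir="out",
    per_device_eval_batch_size=1,
    eval_accumulation_steps=4,    # move accumulated logits to CPU every 4 steps
    predict_with_generate=False,  # skip generation-based metrics (ROUGE etc.)
)
```

Passing `preprocess_logits_for_metrics=lambda logits, _: logits.argmax(-1)` to the `Trainer` is another common way to avoid accumulating full vocabulary-sized logits during evaluation.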

JulesBelveze commented 5 months ago

@tarekziade I'll have a look at your script and see if I can spot a bug or something. Out of curiosity, have you tried shrinking only the decoder?
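
With the hypothetical `shrink_t5` helper sketched above, the decoder-only variant would just keep the encoder intact:

```python
# Decoder-only shrinking: keep all 6 encoder layers, halve the decoder.
student = shrink_t5(
    teacher,
    encoder_layers=(0, 1, 2, 3, 4, 5),  # encoder unchanged
    decoder_layers=(0, 2, 4),           # every other decoder layer
)
```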

tarekziade commented 5 months ago

@JulesBelveze yeah, that's my latest run. It's still running and converging around a loss of 3.5; see the chart below.

[training-loss chart: 3 decoder layers]

JulesBelveze commented 5 months ago

@tarekziade Okay great, that's something! 🤓 Keep me posted on the model's overall performance, or let me know if there's anything else I can help with.