elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)
Apache License 2.0

Finetuning Whisper with Bumblebee / Axon #265

Closed: omarelb closed this issue 7 months ago

omarelb commented 8 months ago

Hi! Thanks for the great work on Bumblebee! I'm having some trouble figuring out how to fine-tune Whisper with Bumblebee and Axon. I have previously fine-tuned Whisper in Python (this script was a great help), but would love to figure out how to do it in Elixir.

The part I'm struggling with most is getting the inputs into the model and the logits out, so that the loss can be computed over the whole output sequence. The Bumblebee.Audio.SpeechToTextWhisper module seems to handle much of this logic, but it is mostly concerned with building the Serving. How involved would it be to write the training loop for Whisper? I would greatly appreciate any pointers in the right direction. Thanks!
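A minimal pure-Elixir sketch of the teacher-forcing layout that lets a loss be computed over the sequence, assuming the transcripts are already tokenized into lists of ids (ideally ending with the end-of-text token). Seq2SeqLabels and build/3 are hypothetical names, and using -100 for ignored label positions is a convention borrowed from the Hugging Face Python script, not something Bumblebee requires.

```elixir
defmodule Seq2SeqLabels do
  # Hypothetical helper, not part of Bumblebee.
  # -100 marks label positions the loss should skip (Hugging Face convention).
  @ignore_index -100

  # Builds {decoder_inputs, labels} for a batch of token id lists.
  # The decoder input at step i is the token *before* the label at step i:
  # the start token followed by the targets shifted right by one.
  # Both are right-padded to the longest sequence in the batch,
  # decoder inputs with pad_id and labels with @ignore_index.
  def build(token_lists, start_id, pad_id) do
    max_len = token_lists |> Enum.map(&length/1) |> Enum.max()

    token_lists
    |> Enum.map(fn tokens ->
      len = length(tokens)
      pad = max_len - len
      decoder_input = Enum.take([start_id | tokens], len)

      {decoder_input ++ List.duplicate(pad_id, pad),
       tokens ++ List.duplicate(@ignore_index, pad)}
    end)
    |> Enum.unzip()
  end
end

{decoder_inputs, labels} = Seq2SeqLabels.build([[1, 2, 3], [4, 5]], 9, 0)
# decoder_inputs => [[9, 1, 2], [9, 4, 0]]
# labels         => [[1, 2, 3], [4, 5, -100]]
```

Wrapping each list with Nx.tensor/1 gives the decoder-token input and the target tensor; a loss that skips the -100 positions then averages only over real tokens, so padding does not dilute the gradient.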

jonatanklosko commented 7 months ago

Hey, sorry I forgot to reply! We have an example of fine-tuning BERT. Loading Whisper already gives you the Axon model, so you can do logits_model = Axon.nx(model, & &1.logits) to output the logits directly for training, as in the example. You'll need to write the data preparation yourself, probably along the lines of the Python script. Note that the serving is only for inference and deals with other concerns; the only part relevant to training is the featurizer, which you'd apply to the training audio.
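A minimal sketch of the approach described above, assuming a recent Bumblebee (~> 0.5, which brings in Axon, Polaris and Nx) plus EXLA, and the openai/whisper-tiny checkpoint. The WhisperFinetune module, its hand-written masked loss, the -100 padding convention (borrowed from the Hugging Face script), the "decoder_input_ids" input name and the dummy batch are illustrative assumptions, not the maintainers' recipe. Real training would replace the silent clips and made-up token ids with featurized audio and tokenized transcripts.

```elixir
# Sketch only: assumes Bumblebee ~> 0.5 (with Axon, Polaris, Nx) and EXLA.
Mix.install([
  {:bumblebee, "~> 0.5"},
  {:exla, ">= 0.0.0"}
])

Nx.global_default_backend(EXLA.Backend)

defmodule WhisperFinetune do
  import Nx.Defn

  # Hypothetical, hand-written token-level cross entropy over the decoder sequence.
  # labels: {batch, seq} target ids; positions set to -100 are ignored.
  # logits: {batch, seq, vocab} as returned by the model.
  defn masked_cross_entropy(labels, logits) do
    mask = labels != -100
    safe_labels = Nx.select(mask, labels, 0)
    log_probs = Axon.Activations.log_softmax(logits, axis: -1)

    # Pick the log-probability of each target token.
    token_log_probs =
      log_probs
      |> Nx.take_along_axis(Nx.new_axis(safe_labels, -1), axis: 2)
      |> Nx.squeeze(axes: [2])

    # Mean negative log-likelihood over non-ignored positions only.
    -Nx.sum(token_log_probs * mask) / Nx.max(Nx.sum(mask), 1)
  end
end

repo = {:hf, "openai/whisper-tiny"}
{:ok, %{model: model, params: params}} = Bumblebee.load_model(repo)
{:ok, featurizer} = Bumblebee.load_featurizer(repo)

# As suggested in the thread: keep only the logits output for training.
logits_model = Axon.nx(model, & &1.logits)

# Dummy batch for illustration: two 1-second clips of silence at 16 kHz and
# made-up target ids. 50258 and 50257 are assumed to be <|startoftranscript|>
# and <|endoftext|> for multilingual checkpoints; check the model config.
audio = [Nx.broadcast(0.0, {16_000}), Nx.broadcast(0.0, {16_000})]
features = Bumblebee.apply_featurizer(featurizer, audio)

# Decoder inputs are the targets shifted right behind the start token;
# padded label positions are -100 so the loss skips them.
decoder_input_ids = Nx.tensor([[50258, 1, 2, 3], [50258, 4, 5, 50257]])
labels = Nx.tensor([[1, 2, 3, 50257], [4, 5, 50257, -100]])

train_batches = [{Map.put(features, "decoder_input_ids", decoder_input_ids), labels}]

trained_state =
  logits_model
  |> Axon.Loop.trainer(
    &WhisperFinetune.masked_cross_entropy/2,
    Polaris.Optimizers.adam(learning_rate: 1.0e-5)
  )
  # strict?: false mirrors the BERT fine-tuning example.
  |> Axon.Loop.run(train_batches, params, epochs: 1, compiler: EXLA, strict?: false)
```

Writing the loss by hand keeps the padding mask explicit; Axon's built-in categorical cross entropy could likely be used instead once padding is handled another way.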

If anyone has more ideas or examples, feel free to post here :)

omarelb commented 7 months ago

Thanks Jonatan, I'll see if I can get something working :)