kitzeslab / opensoundscape

Open source, scalable software for the analysis of bioacoustic recordings
http://opensoundscape.org
MIT License

Finetuning Pre-trained Models such as BirdNET or Perch #896

Closed: sunraymoonbeam closed this issue 3 weeks ago

sunraymoonbeam commented 1 year ago

Hi guys,

Thank you for providing such a useful library. I believe it will help me greatly with my Final Year Project on a bioacoustics classification problem.

I would like to fine-tune both BirdNET and Perch on my own data to create a custom model. Is it possible to load BirdNET or Perch weights when creating a CNN object to train? I noticed that both BirdNET and Perch use the efficientnet_b0 architecture, which cnn_architectures.py supports.

However, I noticed that there is no option for loading pre-trained weights for these bioacoustics models when creating the CNN object to train in cnn_architectures.py.

[Screenshot: 2023-10-25 at 16:32:24]

Is it possible to load BirdNET or Perch weights into these CNN models and further fine-tune them? I know there are training scripts available in their respective GitHub repos, but I would like to use OpenSoundscape's trainer and preprocessing classes.

Would something like this work? [Screenshot: 2023-10-25 at 16:36:13]

sammlapp commented 1 year ago

Hi, this is not possible. In the case of BirdNET, the model weights are not open source. The publicly available model is a TensorFlow Lite (tflite) model, which is essentially a compiled object that can produce embeddings or class scores but does not give access to the model's internals, such as its weights.

In the case of Perch, the full model object is open source, but it was trained in JAX and then deployed with TensorFlow. We hope to eventually port the full Perch model to PyTorch, but this is not a trivial conversion. For now, the models available from bioacoustics-model-zoo are just wrappers around the BirdNET and Perch TensorFlow models: they literally run TensorFlow inside.

I'll leave this issue open as a "feature request" for a fully PyTorch version of Perch, which would be fine-tunable. Likewise, if BirdNET eventually publishes a fully open-source model object, we will attempt a similar port to PyTorch.

sammlapp commented 1 year ago

On the other hand, if you just want to use a "frozen" feature extractor and train a classification head (training only the weights of the final classification layers, not the rest of the network), this is easy: use Perch or BirdNET to generate embeddings, then fit a simple classifier such as sklearn.linear_model.LogisticRegression that maps the embeddings to class labels.
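A minimal sketch of the frozen-extractor approach, assuming the embeddings have already been computed with Perch or BirdNET and stored as a NumPy array with one row per audio clip. Random placeholder data stands in for real embeddings so the example runs on its own; the sizes (128-dimensional embeddings, 3 classes, 300 clips) are hypothetical, and the real embedding size depends on the model used.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Placeholder embeddings (hypothetical): in practice each row would be the
# embedding produced by the frozen Perch or BirdNET model for one clip.
rng = np.random.default_rng(0)
n_clips, emb_dim, n_classes = 300, 128, 3  # hypothetical sizes
labels = rng.integers(0, n_classes, size=n_clips)

# Give each class a different mean so the synthetic classes are separable
class_means = rng.normal(0.0, 1.0, size=(n_classes, emb_dim))
embeddings = class_means[labels] + rng.normal(0.0, 0.5, size=(n_clips, emb_dim))

X_train, X_test, y_train, y_test = train_test_split(
    embeddings, labels, test_size=0.25, random_state=0, stratify=labels
)

# Only this classification head is trained; the network that produced
# the embeddings is never updated.
head = LogisticRegression(max_iter=1000)
head.fit(X_train, y_train)

accuracy = head.score(X_test, y_test)
probs = head.predict_proba(X_test)  # shape: (n_test_clips, n_classes)
print(f"test accuracy: {accuracy:.2f}")
```

This treats each clip as having exactly one label. For multi-label data, where several species can be present in one clip, one option is to fit one binary classifier per class, for example with sklearn.multiclass.OneVsRestClassifier wrapped around LogisticRegression.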

sammlapp commented 5 months ago

It's worth mentioning that HawkEars is a recently developed open-source PyTorch bird classifier that is easy to fine-tune in PyTorch. You can use the HawkEars GitHub repo, or get HawkEars as an OpenSoundscape CNN object by loading it from the bioacoustics model zoo: m = torch.hub.load('kitzeslab/bioacoustics-model-zoo', 'HawkEars', trust_repo=True)