kkoutini / PaSST

Efficient Training of Audio Transformers with Patchout
Apache License 2.0

Fine tuning on novel dataset #14

Open beyondbeneath opened 2 years ago

beyondbeneath commented 2 years ago

Hello. Firstly, thank you for this great work! I've already had very promising results using the "scene" embeddings from these models, and I'm now looking to fine-tune a model on a new dataset, similar to ESC50 and others. (As a side note, using scene embeddings and a logistic regression I'm getting acceptably good results, but I'm convinced true fine-tuning would be significantly better.)

I'm having a bit of trouble interpreting the example scripts. Are you able to give a simple explanation of what is required for fine-tuning (e.g. the data format, directories vs JSON file, format of the labels CSV, etc.)? It's quite hard to reverse-engineer this from the code. I have a directory of files with known labels, and simply want to fine-tune a model on it. And once the data is in place, which functions/CLI scripts should be invoked?

Many thanks, and apologies if I'm missing something obvious. I know the AudioSet page has a few more details, but it's still not crystal clear how to proceed. Cheers!

kkoutini commented 2 years ago

Hi, thank you! You can replace AudioSetDataset with any class compatible with torch.utils.data.Dataset. You just need to make sure the __getitem__ method returns a tuple similar to waveform.reshape(1, -1), filename, target. For example, in AudioSetDataset, meta_csv is expected to have the columns filename and target. train is a boolean used to filter the train and test samples here.

The important methods in the dataset file are get_training_set and get_test_set, which are called from here like dataset=CMD("/basedataset.get_training_set").
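Something along these lines should work as a starting point (a minimal sketch; the class name, CSV columns, and resampling details are placeholders to adapt to your data):

```python
import pandas as pd
import torch
import torchaudio
from torch.utils.data import Dataset

class MyAudioDataset(Dataset):
    """Placeholder dataset: assumes a meta CSV with 'filename', 'target',
    and a boolean 'train' column marking the split."""

    def __init__(self, meta_csv, train=True, sample_rate=32000):
        df = pd.read_csv(meta_csv)
        self.df = df[df.train == train].reset_index(drop=True)
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        waveform, sr = torchaudio.load(row.filename)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        waveform = waveform.mean(dim=0, keepdim=True)  # down-mix to mono if needed
        # the framework expects a (waveform, filename, target) tuple
        return waveform.reshape(1, -1), row.filename, row.target

def get_training_set():
    return MyAudioDataset("meta.csv", train=True)

def get_test_set():
    return MyAudioDataset("meta.csv", train=False)
```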

If you want to use the framework to fine-tune the model on a new dataset, I'd recommend:

  1. Make a copy of esc50/dataset.py and edit the AudioSetDataset class to correctly parse your dataset. Alternatively, you can simply make the methods get_training_set and get_test_set return your custom datasets.
  2. Make a copy of ex_esc50.py and edit the dataset config here to point to your new dataset (it should be a python path to your dataset ingredient), along the lines sketched below.
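For example, assuming your copy of the dataset file lives at mydataset/dataset.py and exposes a sacred ingredient named dataset (these names are placeholders, not the repo's exact code), the change in your copy of ex_esc50.py would look roughly like:

```python
# Hypothetical sketch -- the module path and ingredient name are placeholders.
from sacred import Experiment
from mydataset.dataset import dataset  # your copy of the esc50 dataset ingredient

ex = Experiment("mydataset", ingredients=[dataset])
```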
beyondbeneath commented 2 years ago

Thanks for this information!

At the moment, I'm trying to get this running in Google Colab, and it appears it will take some time to resolve the conda/pip/mamba environment, as well as to remove all the MongoDB requirements, which I imagine will be quite difficult to disentangle.

I will report back if I manage to get it working, and how I go with the fine-tuning. Cheers!

ElliottP-13 commented 2 years ago

Hello. I was hoping to use your model to train on the good-sounds dataset (musical instrument recordings, stored in .wav files). At the beginning of epoch 2 (epoch 1 trains successfully, which makes this more confusing) the program crashes with the error:

RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)`

The stack trace is entirely within PyTorch Lightning. I can provide it if needed, but I imagine the problem lies elsewhere. I was wondering if it is some sort of tensor shape error? It gives me an image size warning:

UserWarning: Input image size (128*500) doesn't match model (128*998). warnings.warn(f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") 

I don't know where these dimensions came from. What should the shapes returned by the dataset's __getitem__ be? My understanding is that waveform.reshape(1, -1) essentially flattens the audio to shape (1, num_samples), row.filename is a string, and target is a single number (a long).

Any help or insight would be greatly appreciated!

kkoutini commented 2 years ago

Hi, I'm not sure what could cause the CUDA error. Could it be that you are feeding audio clips longer than 10 seconds to the model?

The 128x500 is the input spectrogram resolution: the pretrained model expects 128 mel-frequency bins by ~1000 time frames (corresponding to 10 seconds). But it shouldn't be a problem to fine-tune the model on shorter audio clips.
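If you want your inputs to match the pretrained input length exactly, one option is to zero-pad or trim each waveform to a fixed 10 seconds before returning it from __getitem__. A minimal sketch, assuming the 32 kHz sample rate used for the pretrained models:

```python
import torch

def pad_or_trim(waveform: torch.Tensor, sample_rate: int = 32000, seconds: int = 10) -> torch.Tensor:
    """Zero-pad or truncate a (1, num_samples) waveform to a fixed length."""
    target_len = sample_rate * seconds
    n = waveform.shape[-1]
    if n < target_len:
        # pad with zeros at the end to reach the target length
        waveform = torch.nn.functional.pad(waveform, (0, target_len - n))
    else:
        # keep only the first `target_len` samples
        waveform = waveform[..., :target_len]
    return waveform
```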