kkoutini / PaSST

Efficient Training of Audio Transformers with Patchout
Apache License 2.0

Getting started with a custom dataset #33

Open OhadCohen97 opened 1 year ago

OhadCohen97 commented 1 year ago

Hi,

Thank you for your great work!

I want to use PaSST on my custom dataset, for a different classification task.

Are there any minimal instructions/code for running the model for a different dataset? From which file should I start?

Does PaSST support multi-channel WAV audio?

Best

kkoutini commented 1 year ago

Hi, thank you! If you want to use your own training framework, check out this repo for some examples: https://github.com/kkoutini/passt_hear21

For example, to get a trainable model that accepts wave inputs:

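import torch
from hear21passt.base import load_model

model = load_model(mode="logits").cuda()
# wave_signal: float tensor of raw mono audio, shape [batch, num_samples]
wave_signal = torch.zeros(1, 32000 * 10).cuda()  # placeholder: 10 s of silence at 32 kHz
logits = model(wave_signal)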

Unfortunately, all the pre-trained models are trained on mono audio.
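Since the pre-trained models expect mono input, one common workaround is to average the channels before passing the audio to the model. The following is a minimal sketch using only the Python standard library. It assumes 16-bit PCM WAV files, and `read_wav_as_mono` is a hypothetical helper, not part of PaSST. It does not resample, so the file should already be at the sample rate the model expects.

```python
import array
import sys
import wave


def read_wav_as_mono(path):
    """Read a 16-bit PCM WAV file and average its channels into one track.

    Returns (samples, sample_rate); samples is a list of floats in [-1, 1).
    """
    with wave.open(path, "rb") as wav_file:
        n_channels = wav_file.getnchannels()
        sample_rate = wav_file.getframerate()
        if wav_file.getsampwidth() != 2:
            raise ValueError("this sketch only handles 16-bit PCM")
        raw = wav_file.readframes(wav_file.getnframes())

    # Samples are interleaved per frame: ch0, ch1, ..., ch0, ch1, ...
    interleaved = array.array("h")
    interleaved.frombytes(raw)
    if sys.byteorder == "big":
        interleaved.byteswap()  # WAV data is little-endian

    mono = []
    for start in range(0, len(interleaved), n_channels):
        frame = interleaved[start:start + n_channels]
        # Average the channels and scale int16 values to floats
        mono.append(sum(frame) / (n_channels * 32768.0))
    return mono, sample_rate
```

The resulting list can be converted to a tensor of shape [1, num_samples] before calling the model. Averaging discards spatial information, so it only makes sense when the channels carry roughly the same content.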

OhadCohen97 commented 1 year ago

Hi, Thank you for your reply!

Actually, I would like to use your training framework and fine-tune the model on my dataset.

Which file should I modify to run training on my data?

kkoutini commented 1 year ago

I think the simplest way to get started is to do something similar to the ESC-50.

Dataset file

You would need a dataset file: https://github.com/kkoutini/PaSST/blob/main/esc50/dataset.py

This file defines the dataset config and the main PyTorch dataset class: https://github.com/kkoutini/PaSST/blob/d27d83228b96b3f88f6f7c2402422f781dbf2d36/esc50/dataset.py#L137

It should also define the methods that return the training and test sets:
https://github.com/kkoutini/PaSST/blob/d27d83228b96b3f88f6f7c2402422f781dbf2d36/esc50/dataset.py#L255
https://github.com/kkoutini/PaSST/blob/d27d83228b96b3f88f6f7c2402422f781dbf2d36/esc50/dataset.py#L271

The config is injected automatically using sacred.

Take a look at https://github.com/kkoutini/PaSST/blob/d27d83228b96b3f88f6f7c2402422f781dbf2d36/esc50/dataset.py#L309 to check that your dataset parses and loads the audio files correctly.
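To illustrate the overall shape of such a dataset file, here is a rough, framework-free sketch. It assumes a hypothetical metadata CSV with `filename`, `fold` and `target` columns, and the class and function signatures are illustrative only. In the actual repo the getters receive their configuration from sacred, and the dataset class also loads the waveform.

```python
import csv
import io


class FoldDataset:
    """Map-style dataset: supports len() and integer indexing."""

    def __init__(self, rows):
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, index):
        row = self.rows[index]
        # A real implementation would load and return the waveform here;
        # this sketch returns only the file name and the integer label.
        return row["filename"], int(row["target"])


def load_metadata(csv_file):
    """Parse the metadata CSV into a list of dict rows."""
    return list(csv.DictReader(csv_file))


def get_training_set(rows, test_fold):
    """All rows except those in the held-out fold."""
    return FoldDataset([r for r in rows if int(r["fold"]) != test_fold])


def get_test_set(rows, test_fold):
    """Only the rows in the held-out fold."""
    return FoldDataset([r for r in rows if int(r["fold"]) == test_fold])


# Tiny in-memory example standing in for a real metadata file
META = """filename,fold,target
a.wav,1,0
b.wav,2,3
c.wav,1,1
"""
rows = load_metadata(io.StringIO(META))
train = get_training_set(rows, test_fold=1)
test = get_test_set(rows, test_fold=1)
print(len(train), len(test))
```

Any object with `__len__` and `__getitem__` can be passed to a PyTorch `DataLoader`, which is why the sketch does not need to subclass anything.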

Experiment file

Then you need to create the experiment file, similar to https://github.com/kkoutini/PaSST/blob/main/ex_esc50.py The experiment file holds the default configs, loads the dataset and the NN model, and contains the training logic as a pytorch-lightning module.

Here, you need to change the dataset ingredient to match your new dataset (a string formatted like a Python import): https://github.com/kkoutini/PaSST/blob/d27d83228b96b3f88f6f7c2402422f781dbf2d36/ex_esc50.py#L68 You can then change the default project name in the wandb config: https://github.com/kkoutini/PaSST/blob/d27d83228b96b3f88f6f7c2402422f781dbf2d36/ex_esc50.py#L71

The rest is to update the pytorch-lightning module if needed: https://github.com/kkoutini/PaSST/blob/d27d83228b96b3f88f6f7c2402422f781dbf2d36/ex_esc50.py#L106

OhadCohen97 commented 1 year ago

Thank you for your reply.

I managed to fine-tune the model on my custom dataset, and I would like to run an evaluation on my test set (not inference).

I added 'get_test_set()' in the "dataset.py" and also added a test loader in the experiment file ("/basedataset.get_test_set").

I already have a model checkpoint, and I wonder how I can load it so that it can fit into the class "M(Ba3lModule)" and the function "evaluate_only" in the experiment file.

Is there any basic code for running the fine-tuned checkpoint in the experiment file?

Thanks!

kkoutini commented 1 year ago

Hi, that's great! I think the easiest way to load the model is to edit the evaluate_only function to something like this:

@ex.command
def evaluate_only(_run, _config, _log, _rnd, _seed, load_ckpt_path="path_to_ckpt"):
    # sacred injects the underscore arguments by name, so the defaulted
    # load_ckpt_path can (and, to be valid Python, must) come last
    # force overriding the config, not logged = not recommended
    trainer = get_trainer(logger=get_logger())
    val_loader = get_validate_loader()

    modul = M(ex)
    ## loading pre-trained weights
    checkpoint = torch.load(load_ckpt_path)  # maybe with map_location=torch.device('cpu')
    modul.load_state_dict(checkpoint['state_dict'])

    ##
    modul.val_dataloader = None
    trainer.val_dataloaders = None
    print(f"\n\nValidation len={len(val_loader)}\n")
    res = trainer.validate(modul, val_dataloaders=val_loader)
    print("\n\n Validation:")
    print(res)

Maybe you also want to load only the transformer weights:

    ## loading pre-trained weights
    checkpoint = torch.load(load_ckpt_path)  # maybe with map_location=torch.device('cpu')
    # keep only the keys under "net." and strip that 4-character prefix
    net_statedict = {k[4:]: v for k, v in checkpoint['state_dict'].items() if k.startswith("net.")}
    modul.net.load_state_dict(net_statedict)

    ##
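The key filtering in the second snippet can be checked in isolation with plain dictionaries. This is only an illustrative sketch: the key names below are hypothetical stand-ins for real checkpoint keys, and the values would be tensors rather than strings.

```python
def strip_prefix(state_dict, prefix):
    """Keep only entries whose key starts with `prefix`, with the prefix removed."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Hypothetical checkpoint keys; real values would be tensors, not strings.
checkpoint_state = {
    "net.patch_embed.proj.weight": "w0",
    "net.head.1.bias": "b0",
    "mel.mel_basis": "m0",
}
net_state = strip_prefix(checkpoint_state, "net.")
print(sorted(net_state))  # ['head.1.bias', 'patch_embed.proj.weight']
```

Using `len(prefix)` instead of a hard-coded `4` keeps the slice correct if the submodule is ever renamed.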

OhadCohen97 commented 1 year ago

Hi,

Thank you for responding. It works!

Regarding my question about multi-channel audio: can I train PaSST on multi-channel audio from scratch (not from the pre-trained weights)? Does the model/code support training on this kind of data?

kkoutini commented 1 year ago

Yes, you can change the in_channels parameter here: https://github.com/kkoutini/PaSST/blob/2ef1fd9f834e17302aebb2611831aacc30c50fcc/models/passt.py#L939C80-L939C80
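One way to see why this implies training from scratch: the first layer's weight shape depends on the number of input channels, so the mono pre-trained weights no longer fit it. The sketch below just does the arithmetic for a convolutional patch embedding. The embedding size of 768 and the 16x16 patch are assumed ViT-Base-style defaults, not values taken from this thread, so check them against passt.py.

```python
def patch_embed_param_count(in_channels, embed_dim=768, patch_size=(16, 16)):
    """Parameter count of a Conv2d patch embedding: weights plus one bias per output channel."""
    kernel_h, kernel_w = patch_size
    return embed_dim * in_channels * kernel_h * kernel_w + embed_dim


# Compare mono input against two- and four-channel input
for channels in (1, 2, 4):
    print(channels, patch_embed_param_count(channels))
```

Only this input layer grows with the channel count; the transformer blocks after it are unaffected by the change.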

CodeBot-del commented 1 year ago

Hello @OhadCohen97, it is great news that you managed to fine-tune the model on your dataset using the framework. Can you please explain how you achieved it? I am trying to do the same thing and I am lost as to where to start or what to do. I would appreciate a step-by-step explanation from start to finish.