RetroCirce / HTS-Audio-Transformer

The official code repo of "HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection"
https://arxiv.org/abs/2202.00874
MIT License

Getting started with a custom dataset #34

Open OhadCohen97 opened 1 year ago

OhadCohen97 commented 1 year ago

Hi,

Thank you for your excellent work!

I want to use HTS-Audio-Transformer on my custom dataset for a different classification task.

Are there any instructions on how to run the model on a different dataset? Which file should I start from?

Thanks

RetroCirce commented 1 year ago

Hi, sorry for the late reply. To use the model on a different dataset, you need to construct a new data loader and dataset class; you can use SEDDataset as a reference.
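For illustration, here is a minimal sketch of what such a dataset class might look like for a folder of labeled WAV files. This is not the repo's code: the dict keys ("audio_name", "waveform", "target") are modeled on what the repo's dataset classes return, and the config fields (sample_rate, clip_samples, classes_num) are assumed to match config.py; adapt both to your setup.

```python
import numpy as np
import librosa
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Hypothetical dataset class for standard WAV files with one integer label each."""

    def __init__(self, file_list, labels, config):
        self.file_list = file_list  # list of paths to .wav files (assumption)
        self.labels = labels        # list of integer class indices (assumption)
        self.config = config        # assumed to expose sample_rate, clip_samples, classes_num

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        path = self.file_list[idx]
        waveform, _ = librosa.load(path, sr=self.config.sample_rate)

        # pad or truncate to a fixed clip length so batches have a uniform shape
        clip_samples = self.config.clip_samples
        if len(waveform) < clip_samples:
            waveform = np.pad(waveform, (0, clip_samples - len(waveform)))
        else:
            waveform = waveform[:clip_samples]

        # one-hot target vector over the classes of your task
        target = np.zeros(self.config.classes_num, dtype=np.float32)
        target[self.labels[idx]] = 1.0

        return {
            "audio_name": path,
            "waveform": waveform.astype(np.float32),
            "target": target,
        }
```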

OhadCohen97 commented 1 year ago

Hi, thank you for your response.

Is it OK to use SCV2_Dataset or DESED_Dataset? I see they are regular dataset classes, which seem better suited for loading my standard WAV audio files. How do they differ from SEDDataset?

Can HTS-AT support multi-channel WAV audio?

Thank you.

RetroCirce commented 1 year ago

Hi,

SCV2 is for the Speech Commands V2 dataset, DESED is for the sound event detection dataset, and ESC is for the ESC-50 dataset. I think the dataset class for SCV2 might be the best fit; you can adapt it into your own dataset.

Yes, it is possible to support multi-channel audio, but you would first need to change the first layer to map more than one channel to the deep features, which means the pretrained model would no longer be usable. Alternatively, you can merge the multi-channel audio into a single channel, or perform classification on each channel separately and average the results.
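As a rough sketch of the last two options (not the repo's code): `model` is assumed to be an HTS-AT instance whose forward takes a batch of mono waveforms and returns a tensor of class scores; if your model returns a dict as in the repo, average the clip-wise output entry instead.

```python
import torch

def downmix_to_mono(waveform: torch.Tensor) -> torch.Tensor:
    # waveform: (channels, samples) -> (samples,)
    # averaging channels keeps the pretrained single-channel front end usable
    return waveform.mean(dim=0)

def per_channel_average(model, waveform: torch.Tensor) -> torch.Tensor:
    # waveform: (channels, samples)
    # run the single-channel model on each channel and average the predictions
    outputs = []
    with torch.no_grad():
        for ch in range(waveform.shape[0]):
            scores = model(waveform[ch].unsqueeze(0))  # add batch dim: (1, samples)
            outputs.append(scores)
    return torch.stack(outputs, dim=0).mean(dim=0)
```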

OhadCohen97 commented 1 year ago

Hi, thank you for getting back to me!

1) Regarding the multi-channel audio case, I have considered using PatchEmbed to process each channel and then summing the results so that they fit into the "forward_features" function (a rough sketch is at the end of this comment). In your experience, is there another approach that could establish a connection between the channels for better classification?

2) Which of the hyperparameters in the 'config.py' file do I need to consider in order to properly fine-tune on my dataset (a different classification task)?

Thank you.
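For point 1, this is roughly what I have in mind (only a sketch under my own assumptions: `PatchEmbed` stands in for the model's existing single-channel patch embedding, and the input is one spectrogram per channel):

```python
import torch
import torch.nn as nn

class MultiChannelPatchEmbed(nn.Module):
    """Hypothetical wrapper: embed each channel's spectrogram with the same
    single-channel patch embedding, then sum the token sequences so the rest
    of forward_features still sees the usual (batch, tokens, dim) shape."""

    def __init__(self, patch_embed: nn.Module):
        super().__init__()
        self.patch_embed = patch_embed  # the existing single-channel patch embedding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, freq_bins, time_frames) -- one spectrogram per channel
        tokens = None
        for ch in range(x.shape[1]):
            # keep a singleton channel dim so the patch embed sees (B, 1, F, T)
            embedded = self.patch_embed(x[:, ch:ch + 1])
            tokens = embedded if tokens is None else tokens + embedded
        return tokens
```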