YuanGongND / ast

Code for the Interspeech 2021 paper "AST: Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

Prediction always wrong using esc50 recipe with 0.95+ accuracy after training #49

Closed kremHabashy closed 2 years ago

kremHabashy commented 2 years ago

Thank you for this paper, it is very well written and documented. Sorry for the confusing title. I ran the esc50 recipe and it worked as expected. This is the accuracy obtained per fold:

0.950, 0.983, 0.935, 0.970, 0.943, 0.956

I am trying to use the best model produced to manually classify some audio files (later I want to use the model on my own dataset). This is the code I am running:

torch.cuda.set_device('cuda:0')
device = torch.device("cuda:0")

pretrained_mdl_path = "/home/habashyk/virtualEnvs/ast/egs/esc50/exp/test-esc50-f10-t10-impTrue-aspTrue-b48-lr1e-5/fold2/models/best_optim_state.pth"
sd = torch.load(pretrained_mdl_path, map_location=device)

ast_mdl = ASTModel(label_dim=50,
                   fstride=10,
                   tstride=10,
                   input_fdim=audio_conf_AS["num_mel_bins"],
                   input_tdim=audio_conf_AS["target_length"],
                   imagenet_pretrain=True,
                   model_size='base384')
ast_mdl = torch.nn.DataParallel(ast_mdl)
ast_mdl.load_state_dict(sd, strict=False)
ast_mdl.cuda()
ast_mdl.eval()

Unfortunately, when I use this model, the predictions are never accurate, even though they come with very high probabilities.

Top 3 labels and their associated probabilities for each prediction:
THIS IS BATCH 0
Wav 0: Ground truth:  dog
Label:  Cough   Prob:  0.80712890625
Label:  Female speech, woman speaking   Prob:  0.74560546875
Label:  Throat clearing     Prob:  0.7314453125
Wav 1: Ground truth:  chirping_birds
Label:  Child singing   Prob:  0.78369140625
Label:  Cough   Prob:  0.73779296875
Label:  Sneeze  Prob:  0.720703125
Wav 2: Ground truth:  vacuum_cleaner
Label:  Narration, monologue    Prob:  0.67578125
Label:  Children shouting   Prob:  0.6611328125
Label:  Baby laughter   Prob:  0.66015625

I have used the same code with the AudioSet model and its associated .pth weights and it works fine. Any insight on this would be greatly appreciated. Please let me know of anything else I can provide.

Also, using audioset_pretrain = True yields the same result: high probabilities with incorrect classes.

Thank you!

YuanGongND commented 2 years ago

Hi there,

Maybe I need to see the full inference code. Have you normalized your input?

You can refer to https://github.com/YuanGongND/ast/blob/master/egs/audioset/inference.py but you need to change the mean/std for normalization.
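For context, the repo's dataloader pads or truncates the log-mel filterbank to `target_length` frames and then normalizes it as `(fbank - mean) / (2 * std)`, so inference must use the same dataset-specific statistics the model was trained with. A minimal sketch of that pad-and-normalize step is below; the `norm_mean`/`norm_std` values in the usage line are what I believe the ESC-50 recipe uses, but verify them against `egs/esc50/run_esc.sh` rather than reusing the AudioSet statistics:

```python
import numpy as np

def prepare_fbank(fbank: np.ndarray, target_length: int,
                  norm_mean: float, norm_std: float) -> np.ndarray:
    """Pad/truncate a (time, mel) filterbank to target_length frames,
    then normalize the way the AST dataloader does: (x - mean) / (2 * std)."""
    n_frames = fbank.shape[0]
    if n_frames < target_length:
        # zero-pad at the end of the time axis
        pad = np.zeros((target_length - n_frames, fbank.shape[1]),
                       dtype=fbank.dtype)
        fbank = np.concatenate([fbank, pad], axis=0)
    else:
        # truncate to the first target_length frames
        fbank = fbank[:target_length]
    return (fbank - norm_mean) / (2 * norm_std)

# Usage: stats must match the training recipe (ESC-50 here, not AudioSet).
feats = prepare_fbank(np.random.randn(400, 128).astype(np.float32),
                      target_length=512,
                      norm_mean=-6.6268077, norm_std=5.358466)
```

If the AudioSet mean/std are used on ESC-50 features (or normalization is skipped entirely), the input distribution the model sees is shifted, which produces exactly this failure mode: confident but wrong predictions.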

-Yuan

YuanGongND commented 2 years ago

Just want to follow up, have you solved the problem yet? Thanks!