YuanGongND / ast

Code for the Interspeech 2021 paper "AST: Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

Question about pre-training on a new dataset. #47

Closed devesh-k closed 2 years ago

devesh-k commented 2 years ago

Hi, I am trying to use the pre-trained model on my own dataset and in my own pipeline. As recommended, I am using audioset_pretrain=True and imagenet_pretrain=True. In the code I noticed that ASTModel is called again inside its own constructor (line 129 in ast_models.py), which results in an infinite loop. Below is the snippet I am referring to. Is this a bug or an oversight on my part? Could you please take a look? I am really looking forward to trying AST in my pipeline.

```python
# now load a model that is pretrained on both ImageNet and AudioSet
elif audioset_pretrain == True:
    if audioset_pretrain == True and imagenet_pretrain == False:
        raise ValueError('currently model pretrained on only audioset is not supported, please set imagenet_pretrain = True to use audioset pretrained model.')
    if model_size != 'base384':
        raise ValueError('currently only has base384 AudioSet pretrained model.')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if os.path.exists('../../pretrained_models/audioset_10_10_0.4593.pth') == False:
        # this model performs 0.4593 mAP on the audioset eval set
        audioset_mdl_url = 'https://www.dropbox.com/s/cv4knew8mvbrnvq/audioset_0.4593.pth?dl=1'
        wget.download(audioset_mdl_url, out='../../pretrained_models/audioset_10_10_0.4593.pth')
    sd = torch.load('../../pretrained_models/audioset_10_10_0.4593.pth', map_location=device)
    # the call in question:
    audio_model = ASTModel(label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=False, audioset_pretrain=False, model_size='base384', verbose=False)
```

Thanks in advance for your time.

MikeKras commented 2 years ago

Hi @devesh-k,

although I am not the author I think I can answer that question - this will not cause an infinite loop.

The line that you refer to calls ASTModel with audioset_pretrain=False, so the inner call skips the elif audioset_pretrain == True branch and never calls itself again.

I ran the pipeline and it worked, so hopefully it will for you as well :)
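To make the argument concrete, here is a toy sketch of the same pattern. It is not the AST code: `ToyModel`, its `pretrained` flag, and the hard-coded weights are hypothetical stand-ins. It only shows that a constructor which calls itself with the flag switched off runs exactly twice and then stops.

```python
class ToyModel:
    """Hypothetical stand-in for a model whose constructor calls itself once."""

    constructor_calls = 0  # counts how many times __init__ runs

    def __init__(self, pretrained=True):
        ToyModel.constructor_calls += 1
        self.weights = [0.0, 0.0, 0.0]  # freshly initialised weights
        if pretrained:
            # The inner call passes pretrained=False, so it skips this branch.
            inner = ToyModel(pretrained=False)
            # Stand-in for loading a checkpoint into the inner model.
            inner.weights = [0.1, 0.2, 0.3]
            self.weights = inner.weights


model = ToyModel(pretrained=True)
print(ToyModel.constructor_calls)
```

The counter ends at 2: one outer call and one inner call, with no further recursion.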

YuanGongND commented 2 years ago

Yes, I agree.

The reason is that at https://github.com/YuanGongND/ast/blob/840f0d74a7ecb2fad4c333b79dd5239711b203f2/src/models/ast_models.py#L127, audioset_pretrain is set to False, so the inner call won't enter that branch again.

If you do get an error, please paste the code you use to create the AST model and the error message. Thanks!

devesh-k commented 2 years ago

Thanks so much for taking a look. I am sorry for not being clear. I was referring to the README, which recommends setting audioset_pretrain=True, and I changed the value in two places. Are we not supposed to change the parameter's value at the constructor level?

`def __init__(self, label_dim=2, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=True, audioset_pretrain=True, model_size='base384', verbose=True):`

when the audio_model is called:

MikeKras commented 2 years ago

If you changed line 127 that Yuan referred to: https://github.com/YuanGongND/ast/blob/840f0d74a7ecb2fad4c333b79dd5239711b203f2/src/models/ast_models.py#L127

then indeed it would be problematic, since every constructor call would trigger another one and the recursion would never terminate.
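A toy sketch of that failure mode, with the same caveat that `BrokenToyModel` is a hypothetical class and not the AST code: if the inner call keeps the flag on, Python keeps nesting constructor calls until it raises `RecursionError`.

```python
class BrokenToyModel:
    """Hypothetical model whose inner call wrongly keeps pretrained=True."""

    def __init__(self, pretrained=True):
        self.weights = [0.0, 0.0]
        if pretrained:
            # Bug: the inner call takes this branch again, and again, ...
            inner = BrokenToyModel(pretrained=True)
            self.weights = inner.weights


def build_raises_recursion_error():
    """Return True if constructing the broken model hits the recursion limit."""
    try:
        BrokenToyModel(pretrained=True)
    except RecursionError:
        return True
    return False


print(build_raises_recursion_error())
```

In practice this shows up as a "maximum recursion depth exceeded" traceback rather than a program that hangs forever.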

The purpose of this second call within __init__ (which should have audioset_pretrain=False), as I understand it, is to start with the graph representation (and weights, but we will replace them anyway) of the base DeiT model, so that we can update it with the changes necessary to run the AST.

If what you mean here

when the audio_model is called:

audio_model = ASTModel(label_dim=2, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=True, audioset_pretrain=True, model_size='base384', verbose=True)

is simply creating the instance of the model in some code that you wrote - it will work fine.

I also don't think it's needed (or recommended) to change the default values in the constructor. The instructions only specify which values to pass in the call when you create an instance.
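As a small illustration of the difference, here is a sketch with a hypothetical `ToyAST` class. Its parameter names mirror the thread, but the default values are assumptions made for the example. Passing keyword arguments at the call site overrides the defaults for that one instance and leaves the class definition untouched.

```python
import inspect


class ToyAST:
    """Hypothetical class; the defaults below are assumed for illustration."""

    def __init__(self, label_dim=527, imagenet_pretrain=True, audioset_pretrain=False):
        self.label_dim = label_dim
        self.imagenet_pretrain = imagenet_pretrain
        self.audioset_pretrain = audioset_pretrain


# Override the values where the instance is created, not in the signature.
model = ToyAST(label_dim=2, imagenet_pretrain=True, audioset_pretrain=True)

# The defaults declared in the signature are unchanged by the call above.
defaults = {
    name: param.default
    for name, param in inspect.signature(ToyAST.__init__).parameters.items()
    if param.default is not inspect.Parameter.empty
}
print(model.audioset_pretrain, defaults["audioset_pretrain"])
```

Any internal call that relies on the defaults, or passes its own explicit values, keeps behaving the way the author wrote it.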

Hope this helps :)

devesh-k commented 2 years ago

Thanks for the clarification. I clearly misunderstood and changed the values in both places. I will update my code and try again.