YuanGongND / ast

Code for the Interspeech 2021 paper "AST: Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

How can I adapt the pretrained AST model to fit my own dataset #121

Closed · zky-66 closed this issue 3 months ago

zky-66 commented 4 months ago

Hello, I would like to use my own dataset, where the input is a spectrogram with dimensions [32, 513, 313]. However, the pretrained AST model fixes input_fdim=128 and input_tdim=1024. How can I adapt the pretrained AST model to fit my own dataset? Looking forward to your assistance.

YuanGongND commented 4 months ago

hi there,

Is this an audio spectrogram?

I assume 32 is the batch size, 513 is the time dimension, and 313 is the number of frequency bins, right?

zky-66 commented 4 months ago

> hi there,
>
> Is this an audio spectrogram?
>
> I assume 32 is the batch size, 513 is the time dimension, and 313 is the number of frequency bins, right?

Yes, the following problem occurred:

    RuntimeError: Error(s) in loading state_dict for DataParallel:
    size mismatch for module.v.pos_embed: copying a param with shape torch.Size([1, 1214, 768]) from checkpoint, the shape in current model is torch.Size([1, 602, 768]).

YuanGongND commented 4 months ago

this error is expected.
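(For context, assuming the default 16×16 patches with fstride = tstride = 10: the model keeps one positional embedding per patch plus two special tokens, i.e. (⌊(input_fdim − 16)/10⌋ + 1) × (⌊(input_tdim − 16)/10⌋ + 1) + 2. The 128×1024 checkpoint therefore has 12 × 101 + 2 = 1214 positions, while 602 corresponds to a 128×513 input: 12 × 50 + 2.)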

The problem is your number of frequency bins. What is the sampling rate of your audio? If it is 16kHz, please use our code to convert it to 128 frequency bins; if it is not, the best way is to resample to 16kHz first and then use our code to convert to 128 frequency bins.
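For reference, a minimal sketch of that conversion, along the lines of the feature extraction in this repo's dataloader (the exact keyword arguments in your local copy may differ, and `sample.wav` is just a placeholder):

    import torch
    import torchaudio

    # load a waveform; AST expects 16kHz mono audio
    waveform, sr = torchaudio.load('sample.wav')  # placeholder file name
    assert sr == 16000, 'resample to 16kHz first'

    # 128 mel-filterbank bins, 10ms frame shift, Kaldi-compatible fbank
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
        window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10)

    print(fbank.shape)  # (num_frames, 128); batch to [B, num_frames, 128] for AST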

-Yuan

zky-66 commented 4 months ago

The sampling rate of my audio is 16kHz, but how do I use your code to convert it to 128 frequency bins? I don't know which part of the code to modify. Does input_tdim=1024 need to be modified?

YuanGongND commented 4 months ago

If you do nothing, the code will convert it to 128 bins, i.e., the shape should be [32, 513, 128].

I suggest first running the ESC-50 recipe; it is simple and should not cause any problems.
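Also, if I read the README correctly, constructing the model with `audioset_pretrain=True` downloads the AudioSet checkpoint and adapts the positional embedding to a new `input_tdim` automatically (the frequency dimension stays at the default 128). A sketch under those assumptions; `label_dim=50` is just an example value, and the import path depends on your checkout:

    import os
    import torch
    from src.models import ASTModel  # adjust the import path to your checkout

    # the AudioSet checkpoint is downloaded to TORCH_HOME on first use
    os.environ['TORCH_HOME'] = '../pretrained_models'

    input_tdim = 513  # number of time frames in your spectrograms
    ast_mdl = ASTModel(label_dim=50, input_tdim=input_tdim,
                       imagenet_pretrain=True, audioset_pretrain=True)

    test_input = torch.rand([32, input_tdim, 128])  # [batch, time, 128 mel bins]
    test_output = ast_mdl(test_input)
    print(test_output.shape)  # [32, 50]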

zky-66 commented 4 months ago

> If you do nothing, the code will convert it to 128 bins, i.e., the shape should be [32, 513, 128].
>
> I suggest first running the ESC-50 recipe; it is simple and should not cause any problems.

I simply inserted the following code into my network, where test_input = x_spec.to(device) and the dimensions of x_spec are [32, 513, 313]. Is it correct to use the pretrained model like this?

    import torch
    from src.models import ASTModel  # adjust the import path to your checkout

    pretrained_mdl_path = '/home/zky/ST/pretrained_models/pretrained_models/audioset_10_10_0.4495.pth'
    # get the frequency and time stride of the pretrained model from its file name,
    # e.g. 'audioset_10_10_0.4495.pth' -> fstride=10, tstride=10
    fname = pretrained_mdl_path.split('/')[-1]
    fstride, tstride = int(fname.split('_')[1]), int(fname.split('_')[2].split('.')[0])

    input_tdim = 513
    input_fdim = 313

    if torch.cuda.device_count() > 1:
        device = torch.device("cuda:1")
    else:
        raise Exception("Only one GPU was found, but the code tries to use a second GPU; please check the system configuration")

    # pass the strides parsed above so the patch grid matches the checkpoint
    audio_model = ASTModel(fstride=fstride, tstride=tstride,
                           input_tdim=input_tdim, input_fdim=input_fdim).to(device)
    audio_model = torch.nn.DataParallel(audio_model, device_ids=[1])

    # strict=False tolerates missing/unexpected keys, but size-mismatched
    # tensors (e.g. pos_embed) will still raise
    sd = torch.load(pretrained_mdl_path, map_location=device)
    audio_model.load_state_dict(sd, strict=False)

    test_input = x_spec.to(device)  # x_spec: [32, 513, 313]
    test_output = audio_model(test_input)

    print("test_output:", test_output.shape)