YuanGongND / ast

Code for the Interspeech 2021 paper "AST: Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

Error reshaping positional embedding for AudioSet pretrained model #27

devksingh4 closed this issue 2 years ago

devksingh4 commented 2 years ago

This error only occurs when using the AudioSet pretrained model; it does not occur when using only the ImageNet pretrained model. Audio is resampled to 16 kHz. The error occurs in src/models/ast_models.py: since t_dim > 101, the else block on line 139 is triggered.

Traceback (most recent call last):
  File "train.py", line 73, in <module>
    model = VTN(**vars(cfg))
[REDACTED - model call internally]
  File "/[REDACTED]/ast_models.py", line 141, in __init__
    new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2)
RuntimeError: shape '[1, 768, 120]' is invalid for input of size 221184

Parameters to "AstModel" instantiation:

label_dim: 400
input_tdim: 251
input_fdim: 64
audioset_pretrain: True
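
For reference, the numbers in the traceback follow directly from the patch-grid arithmetic with these parameters; a quick sketch, assuming the default fstride=tstride=10 and 16x16 patches:

    # conv output size per axis: floor((input_size - 16) / stride) + 1
    f_dim = (64 - 16) // 10 + 1      # 5 frequency patches for input_fdim=64
    t_dim = (251 - 16) // 10 + 1     # 24 time patches for input_tdim=251
    num_patches = f_dim * t_dim      # 120, the target of the failing reshape

    # the AudioSet positional embedding is adapted assuming a 12-row frequency grid
    # (see the reshape to (1, 768, 12, t_dim) later in this thread), so at the failing
    # line it holds 768 * 12 * 24 elements, which cannot be flattened to (1, 768, 120)
    # because 120 != 12 * 24
    print(768 * 12 * 24)             # 221184, matching the RuntimeError above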
YuanGongND commented 2 years ago

It is a bug, but setting input_fdim=128 should solve the problem, and I think that could also improve the performance.
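
For example, something like this should work; a rough sketch, assuming the 128-bin features are computed with torchaudio's Kaldi-compatible fbank (as in this repo's dataloader) and that the import path matches your layout:

    import torch
    import torchaudio
    from src.models.ast_models import ASTModel  # adjust to your project layout

    # 128-bin log-mel filterbank at 16 kHz with a 10 ms frame shift
    waveform, sr = torchaudio.load('example.wav')  # placeholder file name
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
        window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10)

    # pad or truncate fbank along the time axis to input_tdim before batching
    model = ASTModel(label_dim=400, input_tdim=251, input_fdim=128, audioset_pretrain=True)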

devksingh4 commented 2 years ago

Hello @YuanGongND, the issue with our data is that we do not have enough audio to use n_mels=128 for the spectrogram, as we end up with mel filterbanks that are all zero. Is there any other workaround for this?

YuanGongND commented 2 years ago

I am not sure why the mel filterbanks are all zero. But I think it is worth trying input_fdim=128 with audioset_pretrain=True; the AudioSet pretrained model is trained with 128 bins, so it might not generalize well to 64-bin input anyway. Otherwise, you can just use the ImageNet pretrained model. I would suggest trying both and comparing the results.
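
Concretely, the two configurations to compare would be something like this (a sketch; the other values are just taken from your parameters above):

    from src.models.ast_models import ASTModel  # adjust to your project layout

    # option 1: AudioSet + ImageNet pretraining with 128 mel bins (matches the pretraining setup)
    model_audioset = ASTModel(label_dim=400, input_tdim=251, input_fdim=128,
                              imagenet_pretrain=True, audioset_pretrain=True)

    # option 2: ImageNet pretraining only, keeping your 64-bin features
    model_imagenet = ASTModel(label_dim=400, input_tdim=251, input_fdim=64,
                              imagenet_pretrain=True, audioset_pretrain=False)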

H-Liu1997 commented 2 years ago

Hi, I also got this size-mismatch issue when setting input_tdim != 1024; just running ast_models.py can reproduce this error.

YuanGongND commented 2 years ago

Hi,

How did you initialize the AST model? You need to specify the input_tdim when you initialize the AST model and your actual input length needs to match the input_tdim.
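
For example (a minimal sketch; the numbers are illustrative):

    import torch
    from src.models.ast_models import ASTModel  # adjust to your project layout

    input_tdim = 256  # must equal the number of time frames in your spectrograms
    model = ASTModel(input_tdim=input_tdim, label_dim=50, audioset_pretrain=True)

    batch = torch.rand(8, input_tdim, 128)   # (batch, time_frames, mel_bins), time_frames == input_tdim
    out = model(batch)                        # fine
    # model(torch.rand(8, 300, 128))         # fails: 300 does not match input_tdim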

-Yuan

H-Liu1997 commented 2 years ago

Yeah, thank you for your reply! I specified input_tdim = 256 with both imagenet_pretrain and audioset_pretrain set to True; after the following modification it works on my machine.

            new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, num_patches, 768).transpose(1, 2).reshape(1, 768, 12, t_dim)
            # if the input sequence length is larger than the original audioset (10s), then cut the positional embedding
#             if t_dim < 101:
#                 print(new_pos_embed.shape)
#                 new_pos_embed = new_pos_embed[:, :, :, 50 - int(t_dim/2): 50 - int(t_dim/2) + t_dim]
#                 print(new_pos_embed.shape)
#             # otherwise interpolate
#             else:
#                 new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(12, t_dim), mode='bilinear')
            new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2)
            self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))
YuanGongND commented 2 years ago

Hi @H-Liu1997,

Your solution can avoid the error, but would also cause a performance drop. I don't suggest commenting out the positional embedding adaptation code.
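
For context, the adaptation that gets lost is the cut-or-interpolate step on the AudioSet positional embedding, which was learned on a 12 x 101 patch grid (128 mel bins x 1024 frames with stride 10); a rough sketch of what it does for a new number of time patches t_dim:

    import torch

    pos_embed = torch.randn(1, 1212, 768)        # AudioSet model: 12 * 101 patches, 768-dim
    grid = pos_embed.transpose(1, 2).reshape(1, 768, 12, 101)

    t_dim = 25                                   # e.g. input_tdim=256 with tstride=10
    if t_dim < 101:
        # shorter input: cut the time axis from the middle
        grid = grid[:, :, :, 50 - int(t_dim / 2): 50 - int(t_dim / 2) + t_dim]
    else:
        # longer input: stretch the time axis by bilinear interpolation
        grid = torch.nn.functional.interpolate(grid, size=(12, t_dim), mode='bilinear')

    new_pos_embed = grid.reshape(1, 768, 12 * t_dim).transpose(1, 2)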

Could you share your code for initializing the AST model and the forward pass (specifically your input shape)?

-Yuan

H-Liu1997 commented 2 years ago

Could you share your code for initializing the AST model and the forward pass (specifically your input shape)?

@YuanGongND Sure! Thank you very much for your help!

import torch
import torch.nn as nn
from torch.cuda.amp import autocast
import os
import wget

os.environ['TORCH_HOME'] = '../../pretrained_models'

import timm
from timm.models.layers import to_2tuple, trunc_normal_

# override the timm package to relax the input shape constraint.
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

class ASTModel(nn.Module):
"""
The AST model.
:param label_dim: the label dimension, i.e., the number of total classes, it is 527 for AudioSet, 50 for ESC-50, and 35 for speechcommands v2-35
:param fstride: the stride of patch spliting on the frequency dimension, for 16*16 patchs, fstride=16 means no overlap, fstride=10 means overlap of 6
:param tstride: the stride of patch spliting on the time dimension, for 16*16 patchs, tstride=16 means no overlap, tstride=10 means overlap of 6
:param input_fdim: the number of frequency bins of the input spectrogram
:param input_tdim: the number of time frames of the input spectrogram
:param imagenet_pretrain: if use ImageNet pretrained model
:param audioset_pretrain: if use full AudioSet and ImageNet pretrained model
:param model_size: the model size of AST, should be in [tiny224, small224, base224, base384], base224 and base 384 are same model, but are trained differently during ImageNet pretraining.
"""
def __init__(self, label_dim=1, fstride=10, tstride=10, input_fdim=128, input_tdim=128, imagenet_pretrain=True, audioset_pretrain=True, model_size='base384', verbose=True):

    super(ASTModel, self).__init__()
    # assert timm.__version__ == '0.4.5', 'Please use timm == 0.4.5, the code might not be compatible with newer versions.'

    if verbose == True:
        print('---------------AST Model Summary---------------')
        print('ImageNet pretraining: {:s}, AudioSet pretraining: {:s}'.format(str(imagenet_pretrain),str(audioset_pretrain)))
    # override timm input shape restriction
    timm.models.vision_transformer.PatchEmbed = PatchEmbed

    # if AudioSet pretraining is not used (but ImageNet pretraining may still apply)
    if audioset_pretrain == False:
        if model_size == 'tiny224':
            self.v = timm.create_model('vit_deit_tiny_distilled_patch16_224', pretrained=imagenet_pretrain)
        elif model_size == 'small224':
            self.v = timm.create_model('vit_deit_small_distilled_patch16_224', pretrained=imagenet_pretrain)
        elif model_size == 'base224':
            self.v = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=imagenet_pretrain)
        elif model_size == 'base384':
            self.v = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=imagenet_pretrain)
        else:
            raise Exception('Model size must be one of tiny224, small224, base224, base384.')
        self.original_num_patches = self.v.patch_embed.num_patches
        self.oringal_hw = int(self.original_num_patches ** 0.5)
        self.original_embedding_dim = self.v.pos_embed.shape[2]
        self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))

        # automatcially get the intermediate shape
        f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
        num_patches = f_dim * t_dim
        self.v.patch_embed.num_patches = num_patches
        if verbose == True:
            print('frequncey stride={:d}, time stride={:d}'.format(fstride, tstride))
            print('number of patches={:d}'.format(num_patches))

        # the linear projection layer
        new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
        if imagenet_pretrain == True:
            new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1))
            new_proj.bias = self.v.patch_embed.proj.bias
        self.v.patch_embed.proj = new_proj

        # the positional embedding
        if imagenet_pretrain == True:
            # get the positional embedding from deit model, skip the first two tokens (cls token and distillation token), reshape it to original 2D shape (24*24).
            new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, self.original_num_patches, self.original_embedding_dim).transpose(1, 2).reshape(1, self.original_embedding_dim, self.oringal_hw, self.oringal_hw)
            # cut (from middle) or interpolate the second dimension of the positional embedding
            if t_dim <= self.oringal_hw:
                new_pos_embed = new_pos_embed[:, :, :, int(self.oringal_hw / 2) - int(t_dim / 2): int(self.oringal_hw / 2) - int(t_dim / 2) + t_dim]
            else:
                new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(self.oringal_hw, t_dim), mode='bilinear')
            # cut (from middle) or interpolate the first dimension of the positional embedding
            if f_dim <= self.oringal_hw:
                new_pos_embed = new_pos_embed[:, :, int(self.oringal_hw / 2) - int(f_dim / 2): int(self.oringal_hw / 2) - int(f_dim / 2) + f_dim, :]
            else:
                new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
            # flatten the positional embedding
            new_pos_embed = new_pos_embed.reshape(1, self.original_embedding_dim, num_patches).transpose(1,2)
            # concatenate the above positional embedding with the cls token and distillation token of the deit model.
            self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))
        else:
            # if not use imagenet pretrained model, just randomly initialize a learnable positional embedding
            # TODO can use sinusoidal positional embedding instead
            new_pos_embed = nn.Parameter(torch.zeros(1, self.v.patch_embed.num_patches + 2, self.original_embedding_dim))
            self.v.pos_embed = new_pos_embed
            trunc_normal_(self.v.pos_embed, std=.02)

    # now load a model that is pretrained on both ImageNet and AudioSet
    elif audioset_pretrain == True:
        if audioset_pretrain == True and imagenet_pretrain == False:
            raise ValueError('currently model pretrained on only audioset is not supported, please set imagenet_pretrain = True to use audioset pretrained model.')
        if model_size != 'base384':
            raise ValueError('currently only has base384 AudioSet pretrained model.')
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if os.path.exists('../../Datasets/checkpoints/audioset_10_10_0.4593.pth') == False:
            # this model performs 0.4593 mAP on the audioset eval set
            audioset_mdl_url = 'https://www.dropbox.com/s/cv4knew8mvbrnvq/audioset_0.4593.pth?dl=1'
            wget.download(audioset_mdl_url, out='../../pretrained_models/audioset_10_10_0.4593.pth')
        sd = torch.load('../../Datasets/checkpoints/audioset_10_10_0.4593.pth', map_location=device)
        audio_model = ASTModel(label_dim=1, fstride=10, tstride=10, input_fdim=128, input_tdim=128, imagenet_pretrain=False, audioset_pretrain=False, model_size='base384', verbose=False)

        audio_model.load_state_dict(sd, strict=False)
        audio_model = torch.nn.DataParallel(audio_model)
        self.v = audio_model.module.v
        self.original_embedding_dim = self.v.pos_embed.shape[2]
        self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))

        f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
        num_patches = f_dim * t_dim
        self.v.patch_embed.num_patches = num_patches
        if verbose == True:
            print('frequncey stride={:d}, time stride={:d}'.format(fstride, tstride))
            print('t_dim={:d}, f_dim={:d}'.format(t_dim, f_dim))
            print('number of patches={:d}'.format(num_patches))
        # print(self.v.pos_embed[:, 2:, :].shape)
        new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, num_patches, 768).transpose(1, 2).reshape(1, 768, 12, t_dim)
        # if the input sequence length is larger than the original audioset (10s), then cut the positional embedding
        # if t_dim < 101:
        #     print(new_pos_embed.shape)
        #     new_pos_embed = new_pos_embed[:, :, :, 50 - int(t_dim/2): 50 - int(t_dim/2) + t_dim]
        #     print(new_pos_embed.shape)
        # # otherwise interpolate
        # else:
        #     new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(12, t_dim), mode='bilinear')
        new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2)
        self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))

def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=128):
    test_input = torch.randn(1, 1, input_fdim, input_tdim)
    test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
    test_out = test_proj(test_input)
    f_dim = test_out.shape[2]
    t_dim = test_out.shape[3]
    #print(f_dim,t_dim)
    return f_dim, t_dim

@autocast()
def forward(self, x):
    """
    :param x: the input spectrogram, expected shape: (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
    :return: prediction
    """
    # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
    x = x.unsqueeze(1)
    x = x.transpose(2, 3)

    B = x.shape[0]
    x = self.v.patch_embed(x)
    cls_tokens = self.v.cls_token.expand(B, -1, -1)
    dist_token = self.v.dist_token.expand(B, -1, -1)
    x = torch.cat((cls_tokens, dist_token, x), dim=1)
    x = x + self.v.pos_embed
    x = self.v.pos_drop(x)
    for blk in self.v.blocks:
        x = blk(x)
    x = self.v.norm(x)
    x = (x[:, 0] + x[:, 1]) / 2

    x = self.mlp_head(x)
    return x

if __name__ == '__main__':
    input_tdim = 256
    ast_mdl = ASTModel(input_tdim=input_tdim)
    # input a batch of 10 spectrograms, each with 256 time frames and 128 frequency bins
    test_input = torch.rand([10, input_tdim, 128])
    test_output = ast_mdl(test_input)
    # output should be in shape [10, label_dim], i.e., 10 samples, each with a prediction per class
    print(test_output.shape)

    input_tdim = 512
    ast_mdl = ASTModel(input_tdim=input_tdim, label_dim=50, audioset_pretrain=True)
    # input a batch of 10 spectrograms, each with 512 time frames and 128 frequency bins
    test_input = torch.rand([10, input_tdim, 128])
    test_output = ast_mdl(test_input)
    # output should be in shape [10, 50], i.e., 10 samples, each with a prediction over 50 classes
    print(test_output.shape)
YuanGongND commented 2 years ago

If you keep the original ast_models.py unchanged and run it, will you get an error? I just ran it and didn't get an error. The second example is similar to your case, right?

    input_tdim = 512
    ast_mdl = ASTModel(input_tdim=input_tdim,label_dim=50, audioset_pretrain=True)
    # input a batch of 10 spectrogram, each with 512 time frames and 128 frequency bins
    test_input = torch.rand([10, input_tdim, 128])
    test_output = ast_mdl(test_input)
    # output should be in shape [10, 50], i.e., 10 samples, each with prediction of 50 classes.
    print(test_output.shape)
YuanGongND commented 2 years ago

And I think switching the order of

audio_model = torch.nn.DataParallel(audio_model)
audio_model.load_state_dict(sd, strict=False)

at line 128 could lead to the issue. Can you tell me the error message you get if you don't switch the order?
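
For reference, a small sketch of why the order matters, assuming the checkpoint was saved from a DataParallel-wrapped model (so its keys carry a 'module.' prefix); with strict=False a key mismatch is silently ignored rather than raising an error:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)  # stand-in for the AST model
    sd = {'module.weight': torch.randn(2, 4), 'module.bias': torch.randn(2)}  # DataParallel-style keys

    # switched order: loading into the bare model skips every key, so no pretrained weights are loaded
    result = model.load_state_dict(sd, strict=False)
    print(result.missing_keys)     # ['weight', 'bias']
    print(result.unexpected_keys)  # ['module.weight', 'module.bias']

    # order used in ast_models.py: wrap first, then load, so the key names line up
    wrapped = nn.DataParallel(model)
    wrapped.load_state_dict(sd)    # loads cleanly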

H-Liu1997 commented 2 years ago

@YuanGongND Sorry for the inconvenience. You are right, switching the order led to the issue. It works well if I leave everything unchanged and just call ASTModel(input_tdim=input_tdim, label_dim=50, audioset_pretrain=True).

The code now works well. Thank you for your help, and sorry for the inconvenience again.

YuanGongND commented 2 years ago

Thanks for the clarification, and great to know the code works.

sreenivasaupadhyaya commented 6 months ago

Hi @YuanGongND and @H-Liu1997, I have the same use case as yours, and I tried the code below as suggested by @YuanGongND. The error disappears if I set audioset_pretrain=False.

    input_tdim = 100
    input_fdim = 64
    ast_mdl = ASTModel(input_tdim=input_tdim, input_fdim=input_fdim, label_dim=50, audioset_pretrain=True, imagenet_pretrain=True)
    # input a batch of 10 spectrograms, each with 100 time frames and 64 frequency bins
    test_input = torch.rand([10, input_tdim, input_fdim])
    test_output = ast_mdl(test_input)
    # output should be in shape [10, 50], i.e., 10 samples, each with a prediction over 50 classes
    print(test_output.shape)

And I got the following error:

[screenshot of the error message]

Any help is appreciated. Thanks.

sreenivasaupadhyaya commented 6 months ago

Any help is appreciated. Thanks.

This issue is solved. The issue was with the mis match when the input was on the CPU. I had to force the input to be on GPU to solve it. However this didnt had the problem in case of imagenet pretrained mode.