devksingh4 closed this issue 3 years ago.
It is a bug, but setting input_fdim=128 should solve the problem, and I think that could also improve the performance.
Hello @YuanGongND, the issue with our data is that we do not have enough audio to use 128 n_mels for the spectrogram; we end up with mel filterbanks with all-zero values. Is there any other workaround for this?
I am not sure why the mel filterbanks are all zero values, but I think it is worth trying input_fdim=128 with audioset_pretrain=True. The AudioSet pretrained model is trained with 128 bins, so it might not generalize well to 64-bin input anyway. Otherwise, you can just use the ImageNet pretrained model. I would suggest trying both and comparing the results.
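For reference, a minimal sketch of the two configurations being compared. The import path, label_dim, and input_tdim values are placeholders; match them to your own setup.

```python
import torch
from ast_models import ASTModel  # adjust the import to where ast_models.py lives in your checkout

input_tdim = 1024  # placeholder: set this to the number of time frames in your spectrograms

# option 1: AudioSet + ImageNet pretraining, which expects 128 mel bins
ast_audioset = ASTModel(label_dim=50, input_fdim=128, input_tdim=input_tdim,
                        imagenet_pretrain=True, audioset_pretrain=True)

# option 2: ImageNet pretraining only, which works with 64 mel bins
ast_imagenet = ASTModel(label_dim=50, input_fdim=64, input_tdim=input_tdim,
                        imagenet_pretrain=True, audioset_pretrain=False)

# sanity check: the frequency dimension of the input must match input_fdim
print(ast_audioset(torch.rand([2, input_tdim, 128])).shape)  # -> [2, 50]
print(ast_imagenet(torch.rand([2, input_tdim, 64])).shape)   # -> [2, 50]
```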
Hi, I also got this size-mismatch issue when setting input_tdim != 1024; just running ast_model.py reproduces the error.
Hi,
How did you initialize the AST model? You need to specify the input_tdim when you initialize the AST model, and your actual input length needs to match the input_tdim.
-Yuan
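To illustrate the point above: the time dimension of the tensor passed to forward has to equal the input_tdim given to the constructor. A minimal sketch, assuming the unmodified ast_models.py; label_dim and the shapes are only illustrative.

```python
import torch
from ast_models import ASTModel  # adjust the import to your checkout

input_tdim = 256  # must equal the number of time frames of your spectrograms
ast_mdl = ASTModel(label_dim=50, input_fdim=128, input_tdim=input_tdim,
                   imagenet_pretrain=True, audioset_pretrain=True)

good_input = torch.rand([4, input_tdim, 128])  # (batch, time, freq) matches input_tdim
print(ast_mdl(good_input).shape)               # -> [4, 50]

bad_input = torch.rand([4, 1024, 128])         # time dimension does not match input_tdim
# ast_mdl(bad_input)  # raises a size-mismatch error when the positional embedding is added
```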
Yeah, thank you for your reply! I specified input_tdim=256 with imagenet_pretrain=True and audioset_pretrain=True; after the following modification it works on my machine:
```python
new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, num_patches, 768).transpose(1, 2).reshape(1, 768, 12, t_dim)
# if the input sequence length is larger than the original audioset (10s), then cut the positional embedding
# if t_dim < 101:
#     print(new_pos_embed.shape)
#     new_pos_embed = new_pos_embed[:, :, :, 50 - int(t_dim/2): 50 - int(t_dim/2) + t_dim]
#     print(new_pos_embed.shape)
# # otherwise interpolate
# else:
#     new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(12, t_dim), mode='bilinear')
new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2)
self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))
```
Hi @H-Liu1997,
Your solution can avoid the error, but would also cause a performance drop. I don't suggest commenting out the positional embedding adaptation code.
Could you share your code of initializing the AST model and forward pass (specifically your input shape)?
-Yuan
> Could you share your code of initializing the AST model and forward pass (specifically your input shape)?
@YuanGongND Sure! Thank you very much for your help!

Previously my input shape was batch_size*128*1024 and it worked well. Now my input shape is batch_size*128*128 and I use the pretrained weights: in ast_model.py I call ASTModel(label_dim=1, fstride=10, tstride=10, input_fdim=128, input_tdim=128, imagenet_pretrain=True, audioset_pretrain=True, model_size='base384', verbose=True) for initialization. I also changed the order of the following two lines, because the original order raises a load_state_dict error on my machine:

```python
# will raise load_state_dict error on my machine
audio_model = torch.nn.DataParallel(audio_model)
audio_model.load_state_dict(sd, strict=False)

# changed to
audio_model.load_state_dict(sd, strict=False)
audio_model = torch.nn.DataParallel(audio_model)
```
```python
# -*- coding: utf-8 -*-
# @Time        : 6/10/21 5:04 PM
# @Author      : Yuan Gong
# @Affiliation : Massachusetts Institute of Technology
# @Email       : yuangong@mit.edu
# @File        : ast_models.py

import os
import wget
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
import timm
from timm.models.layers import to_2tuple, trunc_normal_


class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class ASTModel(nn.Module):
    """
    The AST model.
    :param label_dim: the label dimension, i.e., the number of total classes; it is 527 for AudioSet, 50 for ESC-50, and 35 for speechcommands v2-35
    :param fstride: the stride of patch splitting on the frequency dimension; for 16*16 patches, fstride=16 means no overlap, fstride=10 means overlap of 6
    :param tstride: the stride of patch splitting on the time dimension; for 16*16 patches, tstride=16 means no overlap, tstride=10 means overlap of 6
    :param input_fdim: the number of frequency bins of the input spectrogram
    :param input_tdim: the number of time frames of the input spectrogram
    :param imagenet_pretrain: if use ImageNet pretrained model
    :param audioset_pretrain: if use full AudioSet and ImageNet pretrained model
    :param model_size: the model size of AST, should be in [tiny224, small224, base224, base384]; base224 and base384 are the same model but are trained differently during ImageNet pretraining.
    """
    def __init__(self, label_dim=1, fstride=10, tstride=10, input_fdim=128, input_tdim=128, imagenet_pretrain=True, audioset_pretrain=True, model_size='base384', verbose=True):
        super(ASTModel, self).__init__()
        # assert timm.__version__ == '0.4.5', 'Please use timm == 0.4.5, the code might not be compatible with newer versions.'
        if verbose == True:
            print('---------------AST Model Summary---------------')
            print('ImageNet pretraining: {:s}, AudioSet pretraining: {:s}'.format(str(imagenet_pretrain), str(audioset_pretrain)))
        # override timm input shape restriction
        timm.models.vision_transformer.PatchEmbed = PatchEmbed

        # if AudioSet pretraining is not used (but ImageNet pretraining may still apply)
        if audioset_pretrain == False:
            if model_size == 'tiny224':
                self.v = timm.create_model('vit_deit_tiny_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'small224':
                self.v = timm.create_model('vit_deit_small_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'base224':
                self.v = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'base384':
                self.v = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=imagenet_pretrain)
            else:
                raise Exception('Model size must be one of tiny224, small224, base224, base384.')
            self.original_num_patches = self.v.patch_embed.num_patches
            self.oringal_hw = int(self.original_num_patches ** 0.5)
            self.original_embedding_dim = self.v.pos_embed.shape[2]
            self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))

            # automatically get the intermediate shape
            f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
            num_patches = f_dim * t_dim
            self.v.patch_embed.num_patches = num_patches
            if verbose == True:
                print('frequency stride={:d}, time stride={:d}'.format(fstride, tstride))
                print('number of patches={:d}'.format(num_patches))

            # the linear projection layer
            new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
            if imagenet_pretrain == True:
                new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1))
                new_proj.bias = self.v.patch_embed.proj.bias
            self.v.patch_embed.proj = new_proj

            # the positional embedding
            if imagenet_pretrain == True:
                # get the positional embedding from deit model, skip the first two tokens (cls token and distillation token), reshape it to original 2D shape (24*24).
                new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, self.original_num_patches, self.original_embedding_dim).transpose(1, 2).reshape(1, self.original_embedding_dim, self.oringal_hw, self.oringal_hw)
                # cut (from middle) or interpolate the second dimension of the positional embedding
                if t_dim <= self.oringal_hw:
                    new_pos_embed = new_pos_embed[:, :, :, int(self.oringal_hw / 2) - int(t_dim / 2): int(self.oringal_hw / 2) - int(t_dim / 2) + t_dim]
                else:
                    new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(self.oringal_hw, t_dim), mode='bilinear')
                # cut (from middle) or interpolate the first dimension of the positional embedding
                if f_dim <= self.oringal_hw:
                    new_pos_embed = new_pos_embed[:, :, int(self.oringal_hw / 2) - int(f_dim / 2): int(self.oringal_hw / 2) - int(f_dim / 2) + f_dim, :]
                else:
                    new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
                # flatten the positional embedding
                new_pos_embed = new_pos_embed.reshape(1, self.original_embedding_dim, num_patches).transpose(1, 2)
                # concatenate the above positional embedding with the cls token and distillation token of the deit model.
                self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))
            else:
                # if not using an imagenet pretrained model, just randomly initialize a learnable positional embedding
                # TODO can use sinusoidal positional embedding instead
                new_pos_embed = nn.Parameter(torch.zeros(1, self.v.patch_embed.num_patches + 2, self.original_embedding_dim))
                self.v.pos_embed = new_pos_embed
                trunc_normal_(self.v.pos_embed, std=.02)

        # now load a model that is pretrained on both ImageNet and AudioSet
        elif audioset_pretrain == True:
            if audioset_pretrain == True and imagenet_pretrain == False:
                raise ValueError('currently model pretrained on only audioset is not supported, please set imagenet_pretrain = True to use audioset pretrained model.')
            if model_size != 'base384':
                raise ValueError('currently only has base384 AudioSet pretrained model.')
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if os.path.exists('../../Datasets/checkpoints/audioset_10_10_0.4593.pth') == False:
                # this model performs 0.4593 mAP on the audioset eval set
                audioset_mdl_url = 'https://www.dropbox.com/s/cv4knew8mvbrnvq/audioset_0.4593.pth?dl=1'
                wget.download(audioset_mdl_url, out='../../Datasets/checkpoints/audioset_10_10_0.4593.pth')
            sd = torch.load('../../Datasets/checkpoints/audioset_10_10_0.4593.pth', map_location=device)
            audio_model = ASTModel(label_dim=1, fstride=10, tstride=10, input_fdim=128, input_tdim=128, imagenet_pretrain=False, audioset_pretrain=False, model_size='base384', verbose=False)
            audio_model.load_state_dict(sd, strict=False)
            audio_model = torch.nn.DataParallel(audio_model)
            self.v = audio_model.module.v
            self.original_embedding_dim = self.v.pos_embed.shape[2]
            self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))

            f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
            num_patches = f_dim * t_dim
            self.v.patch_embed.num_patches = num_patches
            if verbose == True:
                print('frequency stride={:d}, time stride={:d}'.format(fstride, tstride))
                print('t_dim={:d}, f_dim={:d}'.format(t_dim, f_dim))
                print('number of patches={:d}'.format(num_patches))

            # print(self.v.pos_embed[:, 2:, :].shape)
            new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, num_patches, 768).transpose(1, 2).reshape(1, 768, 12, t_dim)
            # if the input sequence length is larger than the original audioset (10s), then cut the positional embedding
            new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2)
            self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))

    def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=128):
        test_input = torch.randn(1, 1, input_fdim, input_tdim)
        test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
        test_out = test_proj(test_input)
        f_dim = test_out.shape[2]
        t_dim = test_out.shape[3]
        # print(f_dim, t_dim)
        return f_dim, t_dim

    @autocast()
    def forward(self, x):
        """
        :param x: the input spectrogram, expected shape: (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        :return: prediction
        """
        # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        x = x.unsqueeze(1)
        x = x.transpose(2, 3)

        B = x.shape[0]
        x = self.v.patch_embed(x)
        cls_tokens = self.v.cls_token.expand(B, -1, -1)
        dist_token = self.v.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.v.pos_embed
        x = self.v.pos_drop(x)
        for blk in self.v.blocks:
            x = blk(x)
        x = self.v.norm(x)
        x = (x[:, 0] + x[:, 1]) / 2
        x = self.mlp_head(x)
        return x


if __name__ == '__main__':
    input_tdim = 256
    ast_mdl = ASTModel(input_tdim=input_tdim)
    # input a batch of 10 spectrogram, each with 100 time frames and 128 frequency bins
    test_input = torch.rand([10, input_tdim, 128])
    test_output = ast_mdl(test_input)
    # output should be in shape [10, 527], i.e., 10 samples, each with prediction of 527 classes.
    print(test_output.shape)

    input_tdim = 512
    ast_mdl = ASTModel(input_tdim=input_tdim, label_dim=50, audioset_pretrain=True)
    # input a batch of 10 spectrogram, each with 512 time frames and 128 frequency bins
    test_input = torch.rand([10, input_tdim, 128])
    test_output = ast_mdl(test_input)
    # output should be in shape [10, 50], i.e., 10 samples, each with prediction of 50 classes.
    print(test_output.shape)
```
If you keep the original ast_models.py unchanged and run it, will you get an error? I just ran it and didn't get an error. The second example is similar to your case, right?
```python
input_tdim = 512
ast_mdl = ASTModel(input_tdim=input_tdim, label_dim=50, audioset_pretrain=True)
# input a batch of 10 spectrogram, each with 512 time frames and 128 frequency bins
test_input = torch.rand([10, input_tdim, 128])
test_output = ast_mdl(test_input)
# output should be in shape [10, 50], i.e., 10 samples, each with prediction of 50 classes.
print(test_output.shape)
```
And I think switching the order of audio_model = torch.nn.DataParallel(audio_model) and audio_model.load_state_dict(sd, strict=False) at line 128 could lead to the issue. Can you tell me the error message you get if you don't switch the order?
@YuanGongND Sorry for the inconvenience. You are right, switching the order leads to the issue. It works well if I leave everything unchanged and just call ASTModel(input_tdim=input_tdim, label_dim=50, audioset_pretrain=True) in my __init__ function. Now the code works well; thank you for your help, and sorry for the inconvenience again.
Thanks for the clarification, and great to know the code works.
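A side note for anyone who hits a load_state_dict error at that spot: the released AudioSet checkpoint appears to have been saved from a torch.nn.DataParallel-wrapped model, so its parameter names carry a "module." prefix. Wrapping the model in DataParallel before loading (the original order in the repo) keeps the key names aligned; loading into the bare model first with strict=False can instead silently skip every weight. A quick way to check your checkpoint (adjust the path to wherever you downloaded it):

```python
import torch

sd = torch.load('../../pretrained_models/audioset_10_10_0.4593.pth', map_location='cpu')
print(list(sd.keys())[:3])  # keys starting with 'module.' indicate a DataParallel checkpoint

# if you really need to load into an unwrapped model, strip the prefix first:
# sd = {k.replace('module.', '', 1): v for k, v in sd.items()}
# audio_model.load_state_dict(sd, strict=True)
```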
Hi @YuanGongND and @H-Liu1997, I have the same use case as yours, and I tried the code below as suggested by @YuanGongND. The error disappears if I set audioset_pretrain=False.

```python
input_tdim = 100
input_fdim = 64
ast_mdl = ASTModel(input_tdim=input_tdim, input_fdim=input_fdim, label_dim=50, audioset_pretrain=True, imagenet_pretrain=True)
test_input = torch.rand([10, input_tdim, input_fdim])
test_output = ast_mdl(test_input)
print(test_output.shape)
```

With audioset_pretrain=True I got the following error. Any help is appreciated. Thanks.
This issue is solved. The issue was a device mismatch when the input was on the CPU; I had to force the input onto the GPU to solve it. However, this was not a problem with the ImageNet-only pretrained model.
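For anyone hitting the same device mismatch, a minimal sketch of the fix described above: keep the model and the input on the same device. The parameter values here are only illustrative.

```python
import torch
from ast_models import ASTModel  # adjust the import to your checkout

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ast_mdl = ASTModel(label_dim=50, input_fdim=128, input_tdim=100,
                   imagenet_pretrain=True, audioset_pretrain=True).to(device)

test_input = torch.rand([10, 100, 128]).to(device)  # move the input to the same device as the model
print(ast_mdl(test_input).shape)  # -> [10, 50]
```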
This error only occurs when using the AudioSet pretrained model; it does not occur when using only the ImageNet pretrained model. Audio is resampled to 16 kHz. The error occurs in src/models/ast_models.py: since t_dim > 101, the else block on line 139 is triggered.
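For context, t_dim comes from get_shape, i.e. the output width of the 16x16 patch convolution, so with the default tstride=10 the AudioSet pretraining length input_tdim=1024 gives exactly t_dim=101; a noticeably longer input pushes t_dim past 101 and into the interpolation (else) branch mentioned above. A quick check (values illustrative):

```python
def t_dim_for(input_tdim, tstride=10, patch=16):
    # mirrors ASTModel.get_shape: output width of a conv with kernel 16 and stride tstride
    return (input_tdim - patch) // tstride + 1

print(t_dim_for(1024))  # 101 -> matches the AudioSet pretrained positional embedding
print(t_dim_for(1200))  # 119 > 101 -> the interpolation branch is triggered
```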
Parameters to the ASTModel instantiation: