YuanGongND / ast

Code for the Interspeech 2021 paper "AST: Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

How to change the kernel size? #43

Closed ooobsidian closed 2 years ago

ooobsidian commented 2 years ago

Hello @YuanGongND, I'm sorry to bother you again.

I would like to ask a question: how can I change the kernel size in order to change the number of patches? I am using the ImageNet-pretrained model but NOT the AudioSet-pretrained model, and I get this error:

x1 torch.Size([64, 149, 768])
self.v.pos_embed torch.Size([1, 202, 768])
Traceback (most recent call last):
  File "train.py", line 394, in <module>
    main()
  File "train.py", line 161, in main
    train_loss,train_acc = train(train_loader, model, criterion, optimizer, args.use_cuda, epoch)
  File "train.py", line 296, in train
    output = model(inputs)
  File "/root/miniconda3/envs/deepAST/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/root/miniconda3/envs/deepAST/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 159, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/root/miniconda3/envs/deepAST/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/root/miniconda3/envs/deepAST/lib/python3.7/site-packages/torch/cuda/amp/autocast_mode.py", line 135, in decorate_autocast
    return func(*args, **kwargs)
  File "/data/source/deepAST_exp/model/ASTConcat.py", line 180, in forward
    x1 = x1 + self.v.pos_embed
RuntimeError: The size of tensor a (149) must match the size of tensor b (202) at non-singleton dimension 1

I only changed the get_shape function, like this:

def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=1024, kernel_size=(8, 8)):
    # run a dummy input through the patch-embedding conv to get the output grid size
    test_input = torch.randn(1, 1, input_fdim, input_tdim)
    test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=kernel_size, stride=(fstride, tstride))
    test_out = test_proj(test_input)
    f_dim = test_out.shape[2]
    t_dim = test_out.shape[3]
    return f_dim, t_dim

So, what is the correct way to do this? Looking forward to your answer.

YuanGongND commented 2 years ago

Hi there,

The short answer is no: this is not supported in our code yet, and we do not plan to add that feature.

The reason is that changing the kernel size changes the patch-splitting strategy: with a kernel size of (8,8) you split the spectrogram into 8*8 patches rather than 16*16 patches, while the ImageNet-pretrained models were pretrained on 16*16 patches. The patch-splitting layer, and the positional embedding whose length depends on the number of patches, will therefore complain.
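The mismatch in your traceback follows directly from the convolution output-size formula: with no padding, each axis yields floor((input - kernel) / stride) + 1 positions, so a different kernel size gives a different token count than the one the positional embedding was built for. A minimal sketch (plain Python, using AST's default fstride = tstride = 10 and a 128 x 1024 spectrogram as assumed values):

```python
def num_patches(input_fdim, input_tdim, kernel, fstride, tstride):
    """Number of patch tokens produced by the patch-embedding conv (no padding)."""
    f_dim = (input_fdim - kernel) // fstride + 1  # positions along frequency
    t_dim = (input_tdim - kernel) // tstride + 1  # positions along time
    return f_dim * t_dim

# 16x16 kernel vs 8x8 kernel, same stride: different token counts,
# so a pos_embed sized for one cannot be added to the other.
print(num_patches(128, 1024, 16, 10, 10))  # 12 * 101 = 1212
print(num_patches(128, 1024, 8, 10, 10))   # 13 * 102 = 1326
```

This is why only changing get_shape is not enough: the positional embedding (and the pretrained patch-projection weights) must also match the new token count.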

If you don't use ImageNet pretraining, AST should support arbitrary patch sizes; you will need to change at least this line in addition to the get_shape function, and possibly other things as well. But without ImageNet pretraining, AST's performance is not competitive.
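For the no-ImageNet-pretraining case, one way to make the shapes consistent is to rebuild the patch projection with the new kernel and create a freshly initialized positional embedding of the matching length. This is a hypothetical sketch, not the repo's actual code path; the dimensions (embed_dim=768, 128 x 1024 input, stride 10, +2 tokens for DeiT's cls/dist tokens) are assumptions:

```python
import torch
import torch.nn as nn

embed_dim = 768

# hypothetical: a patch-embedding conv with an 8x8 kernel instead of 16x16
new_proj = nn.Conv2d(1, embed_dim, kernel_size=(8, 8), stride=(10, 10))

# recompute the token grid for the new kernel with a dummy forward pass
with torch.no_grad():
    out = new_proj(torch.randn(1, 1, 128, 1024))
f_dim, t_dim = out.shape[2], out.shape[3]

# fresh positional embedding sized for the new token count
num_tokens = f_dim * t_dim + 2  # +2 for the cls and distillation tokens
pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
nn.init.trunc_normal_(pos_embed, std=0.02)
```

Since the positional embedding is randomly initialized here, it carries no pretrained information, which is consistent with the point above that performance without ImageNet pretraining is not competitive.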

Finally, I want to point you to our recent work SSAST, which I hope will be released soon. SSAST uses self-supervised pretraining as a replacement for ImageNet pretraining, which does not constrain the patch size to be 16*16. Nevertheless, we have not tried 8*8 patches in SSAST either (so no 8*8 pretrained model will be released; you would need to pretrain it yourself, but that is fully supported), as pretraining is expensive. We do plan to release a 128*2 pretrained model.

-Yuan

ooobsidian commented 2 years ago

Thank you very much for your quick reply; it solved my problem. I will continue to follow SSAST.