YuanGongND / ast

Code for the Interspeech 2021 paper "AST: Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

After fine-tuning on a 3-class dataset, how do I load the fine-tuned weights into the pre-trained AST model? #118

Open jmren168 opened 10 months ago

jmren168 commented 10 months ago

Hi @YuanGongND,

egs/audioset/inference.py shows that a pre-trained model can be loaded as follows:

    # 2. load the best model and the weights
    checkpoint_path = args.model_path
    ast_mdl = ASTModel(label_dim=527, input_tdim=input_tdim, imagenet_pretrain=False, audioset_pretrain=False)
    print(f'[*INFO] load checkpoint: {checkpoint_path}')
    checkpoint = torch.load(checkpoint_path, map_location='cuda')
    audio_model = torch.nn.DataParallel(ast_mdl, device_ids=[0])
    audio_model.load_state_dict(checkpoint)

However, when I fine-tuned the pre-trained model on a 3-class dataset and then reloaded it, I got this error message:

size mismatch for module.v.pos_embed: copying a param with shape torch.Size([1, 602, 768]) from checkpoint, the shape in current model is torch.Size([1, 362, 768]).

My Python script:

    input_tdim=312
    class_num=3
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = '/home/...'

    sd = torch.load(model_path, map_location=device)
    model = ASTModel(label_dim=class_num, fstride=10, tstride=10, input_fdim=128, input_tdim=input_tdim, audioset_pretrain=True, model_size='base384',verbose=False)
    model = torch.nn.DataParallel(model)
    model.load_state_dict(sd)

Am I missing something here? Any suggestions are appreciated.

YuanGongND commented 10 months ago

It seems to be caused by an inconsistent input_tdim between training and inference. Could you share the training script (in particular, what is the input_tdim)? Thanks!

jmren168 commented 10 months ago

Thanks for the reply. Here's the training script:

    set=full
    imagenetpretrain=True
    if [ $set == balanced ]
    then
      bal=none
      lr=5e-5
      epoch=25
      #tr_data=/data/sls/scratch/yuangong/aed-pc/src/enhance_label/datafiles_local/balanced_train_data_type1_2_mean.json
      tr_data=./data/datafiles/train_data.json
      lrscheduler_start=10
      lrscheduler_step=5
      lrscheduler_decay=0.5
      wa_start=6
      wa_end=25
    else
      bal=bal
      lr=1e-5
      epoch=15 #5
      tr_data=./data/datafiles/train_data.json
      lrscheduler_start=4 #2
      lrscheduler_step=1 #1
      lrscheduler_decay=0.25 #0.5
      wa_start=1
      wa_end=15 #5
    fi
    #te_data=/data/sls/scratch/yuangong/audioset/datafiles/eval_data.json
    te_data=./data/datafiles/valid_data.json
    freqm=48
    timem=62 # 192
    mixup=0
    # corresponding to overlap of 6 for 16*16 patches
    fstride=10
    tstride=10
    batch_size=4 # 12

    dataset_mean=-4.2677393
    dataset_std=4.5689974
    audio_length=512 #1024
    noise=False

YuanGongND commented 10 months ago

If you set audio_length=512 in training, then shouldn't input_tdim be 512 rather than 312 at inference?
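
The two sizes in the error message can be reproduced with a little arithmetic. The sketch below is not taken from the repository: the helper name `ast_num_positions` is hypothetical, and it assumes 16x16 patches extracted by an unpadded strided convolution plus two extra tokens (class and distillation) in the positional embedding.

```python
def ast_num_positions(input_fdim, input_tdim, fstride=10, tstride=10,
                      patch=16, extra_tokens=2):
    """Estimate the length of the positional embedding (hypothetical helper).

    Assumes patches come from an unpadded convolution with a `patch`-sized
    kernel, and that `extra_tokens` tokens are prepended to the sequence.
    """
    f_dim = (input_fdim - patch) // fstride + 1  # patches along frequency
    t_dim = (input_tdim - patch) // tstride + 1  # patches along time
    return f_dim * t_dim + extra_tokens


# Training setting from the script above: audio_length=512
print(ast_num_positions(128, 512))  # 602, the shape stored in the checkpoint
# Inference setting from the failing script: input_tdim=312
print(ast_num_positions(128, 312))  # 362, the shape of the freshly built model
```

Under these assumptions, 12 frequency patches times 50 time patches plus 2 tokens gives 602, while 12 times 30 plus 2 gives 362, which matches the mismatch reported for module.v.pos_embed.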

jmren168 commented 10 months ago

It works, and thanks again.

BTW, when I load the fine-tuned weights into the AudioSet-pretrained model, should I set audioset_pretrain=True or audioset_pretrain=False?

    model = ASTModel(label_dim=class_num, fstride=10, tstride=10, input_fdim=128, input_tdim=input_tdim, audioset_pretrain=True, model_size='base384', verbose=False)

YuanGongND commented 10 months ago

I guess it doesn't matter.

You can check with model.load_state_dict(sd, strict=True): it raises an error unless the checkpoint covers every parameter of the model, so it does not matter which weights the model was initialized with.

-Yuan
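
To see ahead of time what a strict load would complain about, the parameter names and shapes of the model and the checkpoint can be compared directly. This is a minimal sketch in plain Python; `compare_state_shapes` is a hypothetical helper, and the dictionaries below are illustrative (only the pos_embed shapes come from the error message above).

```python
def compare_state_shapes(model_shapes, ckpt_shapes):
    """Compare two {parameter_name: shape} mappings (hypothetical helper).

    Returns (missing, unexpected, mismatched):
      missing    - names the model expects but the checkpoint lacks
      unexpected - names in the checkpoint that the model does not have
      mismatched - names present in both whose shapes differ,
                   mapped to (checkpoint_shape, model_shape)
    """
    model_keys, ckpt_keys = set(model_shapes), set(ckpt_shapes)
    missing = sorted(model_keys - ckpt_keys)
    unexpected = sorted(ckpt_keys - model_keys)
    mismatched = {
        name: (tuple(ckpt_shapes[name]), tuple(model_shapes[name]))
        for name in sorted(model_keys & ckpt_keys)
        if tuple(ckpt_shapes[name]) != tuple(model_shapes[name])
    }
    return missing, unexpected, mismatched


# Illustrative shapes: model built with input_tdim=312, checkpoint trained with 512.
model_shapes = {"module.v.pos_embed": (1, 362, 768), "module.head.weight": (3, 768)}
ckpt_shapes = {"module.v.pos_embed": (1, 602, 768), "module.head.weight": (3, 768)}

missing, unexpected, mismatched = compare_state_shapes(model_shapes, ckpt_shapes)
print(missing, unexpected)  # [] []
print(mismatched)           # {'module.v.pos_embed': ((1, 602, 768), (1, 362, 768))}
```

With PyTorch, the two mappings could be built as `{k: tuple(v.shape) for k, v in model.state_dict().items()}` and the same comprehension over the loaded checkpoint `sd`. If all three results are empty, every parameter of the model is overwritten by the checkpoint, which is the situation in which the initial weights are irrelevant.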

jmren168 commented 10 months ago

So setting strict=True ensures that all the new weights are loaded. Thanks for the reply.

YuanGongND commented 10 months ago

thanks for letting me know.

-Yuan