YuanGongND / ssast

Code for the AAAI 2022 paper "SSAST: Self-Supervised Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

Trouble with pure inference #4

Open beyondbeneath opened 2 years ago

beyondbeneath commented 2 years ago

Hello!

Firstly, thanks for this great work!

I managed to modify the AudioSet fine-tuning script and fine-tuned a model on a new binary audio classification task. I started with the "Tiny" Patch model and used a batch size of 2. The resulting predictions on the evaluation set looked very promising!

I'm now trying to write an inference script that takes the saved model and performs inference, and I'm running into some trouble. Which method do I actually need to call for pure inference? The documentation seems to describe only pre-training and fine-tuning, not inference.

More pressingly, I can't actually get the model to load. I am trying to load best_audio_model.pth as follows:

input_tdim = 1024
ast_mdl = ASTModel(label_dim=2,
                   fshape=16,
                   tshape=16,
                   fstride=10,
                   tstride=10,
                   input_fdim=128,
                   input_tdim=input_tdim,
                   model_size='tiny',
                   pretrain_stage=False,
                   load_pretrained_mdl_path=MODEL)

However, this results in the following errors:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/content/ssast/src/models/ast_models.py in __init__(self, label_dim, fshape, tshape, fstride, tstride, input_fdim, input_tdim, model_size, pretrain_stage, load_pretrained_mdl_path)
    146                 p_fshape, p_tshape = sd['module.v.patch_embed.proj.weight'].shape[2], sd['module.v.patch_embed.proj.weight'].shape[3]
--> 147                 p_input_fdim, p_input_tdim = sd['module.p_input_fdim'].item(), sd['module.p_input_tdim'].item()
    148             except:

KeyError: 'module.p_input_fdim'

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
1 frames
/content/ssast/src/models/ast_models.py in __init__(self, label_dim, fshape, tshape, fstride, tstride, input_fdim, input_tdim, model_size, pretrain_stage, load_pretrained_mdl_path)
    147                 p_input_fdim, p_input_tdim = sd['module.p_input_fdim'].item(), sd['module.p_input_tdim'].item()
    148             except:
--> 149                 raise  ValueError('The model loaded is not from a torch.nn.Dataparallel object. Wrap it with torch.nn.Dataparallel and try again.')
    150 
    151             print('now load a SSL pretrained models from ' + load_pretrained_mdl_path)

ValueError: The model loaded is not from a torch.nn.Dataparallel object. Wrap it with torch.nn.Dataparallel and try again.

Is there anything obvious I'm missing or doing wrong? Would appreciate any guidance on how to load this model, and also perform an inference on a new .wav file. Thanks!

YuanGongND commented 2 years ago

Thanks for the kind words.

We used multiple GPUs to train the model, so the saved model is a torch.nn.DataParallel object. Even if you only want to do single-GPU inference, you need to do the following:

input_tdim = 1024
ast_mdl = ASTModel(label_dim=2,
                   fshape=16,
                   tshape=16,
                   fstride=10,
                   tstride=10,
                   input_fdim=128,
                   input_tdim=input_tdim,
                   model_size='tiny',
                   pretrain_stage=False,
                   load_pretrained_mdl_path=MODEL)
# convert it to a torch.nn.DataParallel object
ast_mdl = torch.nn.DataParallel(ast_mdl)
# then do inference as normal
output = ast_mdl(input)

Another option is to convert the torch.nn.DataParallel checkpoint back to a plain model state dict. You can search online for the solution.
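For example, a minimal sketch of that conversion, assuming the checkpoint is a plain state dict saved from a DataParallel model (the file names here are placeholders):

import torch
from collections import OrderedDict

# DataParallel prefixes every parameter name with "module."; stripping the
# prefix lets the checkpoint load into a plain (non-parallel) model
sd = torch.load('best_audio_model.pth', map_location='cpu')
plain_sd = OrderedDict((k[len('module.'):] if k.startswith('module.') else k, v)
                       for k, v in sd.items())
torch.save(plain_sd, 'best_audio_model_single.pth')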

-Yuan

YuanGongND commented 2 years ago

Also, the model input should be a spectrogram processed with the same normalization and feature extraction as during training; see https://github.com/YuanGongND/ssast/blob/35ae7abbdd2870c008feed4ece8b7c6457421b17/src/dataloader.py#L195 and https://github.com/YuanGongND/ssast/blob/35ae7abbdd2870c008feed4ece8b7c6457421b17/src/dataloader.py#L126-L127.

You can also refer to https://github.com/YuanGongND/ast/blob/master/egs/audioset/inference.py.
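Roughly, the preprocessing for a single .wav file would look like the sketch below. This is a hedged reconstruction from the linked dataloader lines, not verbatim code from the repo; check the exact fbank arguments against dataloader.py, and note that the mean/std and target_length must match your fine-tuning config ('sample.wav' is a placeholder):

import torch
import torchaudio

waveform, sr = torchaudio.load('sample.wav')
waveform = waveform - waveform.mean()
# log-Mel filterbank features, as in the dataloader
fbank = torchaudio.compliance.kaldi.fbank(
    waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
    window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10)

# pad or truncate to the model's input_tdim
target_length = 1024
p = target_length - fbank.shape[0]
if p > 0:
    fbank = torch.nn.functional.pad(fbank, (0, 0, 0, p))
elif p < 0:
    fbank = fbank[:target_length, :]

# same normalization as training (AudioSet stats from the recipe)
fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)
input_batch = fbank.unsqueeze(0)  # shape (1, target_length, 128)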

beyondbeneath commented 2 years ago

Thanks Yuan for your suggestions.

To be clear, where do I add the DataParallel wrapper? In your example you put it after constructing the ASTModel object, but it is that initial call which fails to load, so I suspect I need to modify ast_models.py. Or are you suggesting I instead convert the serialized model to a parallel one?

YuanGongND commented 2 years ago

You should do something like this: https://github.com/YuanGongND/ast/blob/7b2fe7084b622e540643b0d7d7ab736b5eb7683b/egs/audioset/inference.py#L82-L89

i.e., call audio_model.load_state_dict(checkpoint) after converting the model to a DataParallel object.

YuanGongND commented 2 years ago

I don't suggest changing ast_models.py. Something like the following should work:

input_tdim = 1024
ast_mdl = ASTModel(label_dim=2,
                   fshape=16,
                   tshape=16,
                   fstride=10,
                   tstride=10,
                   input_fdim=128,
                   input_tdim=input_tdim,
                   model_size='tiny',
                   pretrain_stage=False,
                   load_pretrained_mdl_path=MODEL)
# convert it to a torch.nn.DataParallel object
ast_mdl = torch.nn.DataParallel(ast_mdl)
# load the fine-tuned checkpoint (checkpoint_path points to your saved .pth file)
checkpoint = torch.load(checkpoint_path, map_location='cuda')
ast_mdl.load_state_dict(checkpoint)
# then do inference as normal
output = ast_mdl(input)

beyondbeneath commented 2 years ago

Sorry I might not have been clear here.

The ast_mdl = ASTModel(...) line is the one that fails, so I cannot convert the model afterwards: that line never runs.

In that call, MODEL is the saved model file experiment/models/best_audio_model.pth from my previous fine-tuning.

Or are you suggesting that I load the pre-trained model provided by this repo (SSAST-Tiny-Patch-400) and then load my checkpoint on top of it (and if so, is the checkpoint best_audio_model.pth or best_optim_state.pth)?

Does that make sense?

YuanGongND commented 2 years ago

That's weird; if you used my recipe to fine-tune the model, the saved model should already be a DataParallel object.

beyondbeneath commented 2 years ago

Yes... I most certainly used your code.

I essentially used the AudioSet fine-tuning script; here is the full .sh file with my modifications (not many):


set -x
export TORCH_HOME=../../pretrained_models
mkdir -p ./exp

if [ -e SSAST-Tiny-Patch-400.pth ]
then
    echo "pretrained model already downloaded."
else
    wget https://www.dropbox.com/s/ewrzpco95n9jdz6/SSAST-Tiny-Patch-400.pth?dl=1 -O SSAST-Tiny-Patch-400.pth
fi

pretrain_exp=
pretrain_model=SSAST-Tiny-Patch-400
pretrain_path=./${pretrain_exp}/${pretrain_model}.pth

dataset=testdata

dataset_mean=-4.2677393
dataset_std=4.5689974
target_length=1024
noise=False

task=ft_avgtok
model_size=base  # note: this sets the base model size, while the pretrained checkpoint above is Tiny
head_lr=1
warmup=True

bal=none
lr=5e-5
epoch=25
tr_data=/content/drive/MyDrive/ssast_train1.json
te_data=/content/drive/MyDrive/ssast_val1.json
freqm=48
timem=192
mixup=0.5
fstride=10
tstride=10
fshape=16
tshape=16
batch_size=2
exp_dir=./exp/test01-${dataset}-f${fstride}-${fshape}-t${tstride}-${tshape}-b${batch_size}-lr${lr}-${task}-${model_size}-${pretrain_exp}-${pretrain_model}-${head_lr}x-noise${noise}-3

CUDA_CACHE_DISABLE=1 python -W ignore ../../run.py --dataset ${dataset} \
--data-train ${tr_data} --data-val ${te_data} --exp-dir $exp_dir \
--label-csv ./data/class_labels_indices.csv --n_class 2 \
--lr $lr --n-epochs ${epoch} --batch-size $batch_size --save_model False \
--freqm $freqm --timem $timem --mixup ${mixup} --bal ${bal} \
--tstride $tstride --fstride $fstride --fshape ${fshape} --tshape ${tshape} --warmup False --task ${task} \
--model_size ${model_size} --adaptschedule False \
--pretrained_mdl_path ${pretrain_path} \
--dataset_mean ${dataset_mean} --dataset_std ${dataset_std} --target_length ${target_length} \
--num_mel_bins 128 --head_lr ${head_lr} --noise ${noise} \
--lrscheduler_start 10 --lrscheduler_step 5 --lrscheduler_decay 0.5 --wa True --wa_start 6 --wa_end 25 \
--loss BCE --metrics mAP

YuanGongND commented 2 years ago

I see. It might be caused by a bug in the code. I didn't consider your use case.

If the model is not too large, can you send the .pth file to me at yuangong@mit.edu?

I can take a look, but not immediately; I will need to find some spare time.

beyondbeneath commented 2 years ago

FWIW, I got some kind of inference pipeline running, although the results do not match the output originally generated by your fine-tuning recipe, so I'm guessing there are major bugs in what I got working. But I thought it might be relevant anyway. This is all done after I successfully ran the fine-tuning scripts on a new dataset for binary classification:

First, use the same JSON-style approach to make a dataloader (using your AudioDataset dataloader):

# note: freqm/timem/mixup are training-time augmentations; for pure
# evaluation they would normally be set to 0
audio_conf = {
    'num_mel_bins': 128,
    'target_length': 1024,
    'freqm': 48,
    'timem': 192,
    'mixup': 0.5,
    'dataset': 'testdata',
    'mode': 'evaluation',
    'mean': -4.2677393,
    'std': 4.5689974,
    'noise': False
}

train_loader = torch.utils.data.DataLoader(
    dataloader.AudioDataset(val_json,
                            label_csv=labels_csv,
                            audio_conf=audio_conf
                            ),
    batch_size=1,
    shuffle=False)

Next, load the original pre-trained model that I fine-tuned from:

input_tdim = 1024
ast_mdl = ASTModel(label_dim=2,
                   fshape=16,
                   tshape=16,
                   fstride=10,
                   tstride=10,
                   input_fdim=128,
                   input_tdim=input_tdim,
                   model_size='tiny',
                   pretrain_stage=False,
                   load_pretrained_mdl_path='SSAST-Tiny-Patch-400.pth')

Then, load the state checkpoint into this (no idea if this works as expected, but it is the only way I got anything to run without errors):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# note: best_optim_state.pth holds the optimizer state in the recipe; the
# model weights are saved as best_audio_model.pth
sd = torch.load('best_optim_state.pth', map_location=device)
if not isinstance(ast_mdl, torch.nn.DataParallel):
    ast_mdl = torch.nn.DataParallel(ast_mdl)
# strict=False silently skips mismatched keys, so a wrong checkpoint can
# appear to load without errors
ast_mdl.load_state_dict(sd, strict=False)

Then, copying bits and pieces from the supplied traintest.py:

import numpy as np

ast_mdl.eval()
with torch.no_grad():
    for i, (audio_input, labels) in enumerate(train_loader):
        prediction = torch.sigmoid(ast_mdl(audio_input, task='ft_avgtok')).to('cpu').detach()
        print(prediction.shape)
        print(np.array(prediction))

For single samples, the prediction probabilities do not sum to 1, nor do they match the values I expected from the supplied fine-tuning recipe.

YuanGongND commented 2 years ago

The problem is that I used a trick to encode the pretraining hyperparameters in the model, and I use the existence of those hyperparameters to check whether the model is a DataParallel object. The SSL pretraining code does save the hyperparameters, but the fine-tuning code does not, so when you do another round of loading, the code cannot find the hyperparameters and thinks the model is not a DataParallel object.

https://github.com/YuanGongND/ssast/blob/35ae7abbdd2870c008feed4ece8b7c6457421b17/src/models/ast_models.py#L146-L147

For a temporary workaround, you can change these two lines of code:

https://github.com/YuanGongND/ssast/blob/35ae7abbdd2870c008feed4ece8b7c6457421b17/src/models/ast_models.py#L146-L147

I will find a time to fix it.
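One possible edit, as a hedged sketch of that workaround (not the official fix): it assumes the checkpoint was saved from a DataParallel model, so the module.-prefixed weight keys exist and only the hyperparameter tensors are missing.

try:
    p_fshape, p_tshape = sd['module.v.patch_embed.proj.weight'].shape[2], sd['module.v.patch_embed.proj.weight'].shape[3]
    p_input_fdim, p_input_tdim = sd['module.p_input_fdim'].item(), sd['module.p_input_tdim'].item()
except KeyError:
    # fine-tuned checkpoints do not store p_input_fdim/p_input_tdim;
    # fall back to the constructor's input_fdim/input_tdim instead of raising
    p_input_fdim, p_input_tdim = input_fdim, input_tdim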

fanOfJava commented 1 year ago

> The problem is that I used a trick to encode the pretraining hyperparameters in the model, and I use the existence of those hyperparameters to check whether the model is a DataParallel object. The SSL pretraining code does save the hyperparameters, but the fine-tuning code does not, so when you do another round of loading, the code cannot find the hyperparameters and thinks the model is not a DataParallel object.
>
> https://github.com/YuanGongND/ssast/blob/35ae7abbdd2870c008feed4ece8b7c6457421b17/src/models/ast_models.py#L146-L147
>
> For a temporary workaround, you can change these two lines of code:
>
> https://github.com/YuanGongND/ssast/blob/35ae7abbdd2870c008feed4ece8b7c6457421b17/src/models/ast_models.py#L146-L147
>
> I will find a time to fix it.

Even after I changed these two lines, I still cannot load the fine-tuned model.

Lindar1994 commented 1 year ago

Hi, I am running into exactly the same error and am having trouble loading a fine-tuned model. I am wondering if @beyondbeneath ever found a solution to this?

ALexanderSpiridonov commented 1 month ago

Hi, I have the same issue. Any updates or recommendations on how to fix it would be highly appreciated :)