beyondbeneath opened this issue 2 years ago
Thanks for the kind words.
We use multiple GPUs to train the model, so the saved model is a torch.nn.DataParallel object. Even if you want to do single-GPU inference, you need to do the following:
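import torch
from models import ASTModel  # assumes you run this from the repo's src/ directory

input_tdim = 1024
ast_mdl = ASTModel(label_dim=2,
                   fshape=16,
                   tshape=16,
                   fstride=10,
                   tstride=10,
                   input_fdim=128,
                   input_tdim=input_tdim,
                   model_size='tiny',
                   pretrain_stage=False,
                   load_pretrained_mdl_path=MODEL)
# convert it to a torch.nn.DataParallel object
ast_mdl = torch.nn.DataParallel(ast_mdl)
ast_mdl.eval()
# then do inference as normal; fbank is your (batch, time_frames, mel_bins) spectrogram tensor
output = ast_mdl(fbank, task='ft_avgtok')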
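Another option is to convert the torch.nn.DataParallel model back to a plain torch.nn.Module, e.g. by taking its .module attribute or by stripping the "module." prefix from the keys of the saved state dict. You can search online for details.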
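For reference, a minimal sketch of the prefix-stripping option, assuming the checkpoint is a plain state dict saved from a DataParallel wrapper (so every key starts with "module."). The helper name strip_dataparallel_prefix is hypothetical, and the strings stand in for tensors:

```python
from collections import OrderedDict


def strip_dataparallel_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with the DataParallel key prefix removed.

    Keys that do not start with the prefix are kept unchanged.
    """
    stripped = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped


# Toy checkpoint: strings stand in for tensors, key names are illustrative.
sd = OrderedDict([("module.v.cls_token", "t0"), ("module.mlp_head.1.weight", "t1")])
print(list(strip_dataparallel_prefix(sd)))
# ['v.cls_token', 'mlp_head.1.weight']
```

With a real checkpoint you would pass the result of torch.load(path, map_location='cpu') through this helper and then call load_state_dict on an unwrapped model with the same architecture. A model that is already wrapped also exposes the original module as ast_mdl.module.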
-Yuan
Also, the model input should be a spectrogram produced with the same feature extraction and normalization functions used in training: https://github.com/YuanGongND/ssast/blob/35ae7abbdd2870c008feed4ece8b7c6457421b17/src/dataloader.py#L195 and https://github.com/YuanGongND/ssast/blob/35ae7abbdd2870c008feed4ece8b7c6457421b17/src/dataloader.py#L126-L127.
You can also refer to https://github.com/YuanGongND/ast/blob/master/egs/audioset/inference.py.
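As a rough illustration of the padding and normalization step only, here is a NumPy sketch. Assumptions, based on my reading of the linked dataloader lines rather than a guarantee: features are (frames, 128) log-mel filterbanks (the repo computes them with torchaudio.compliance.kaldi.fbank), they are zero-padded or cropped to target_length frames, and they are then normalized as (x - mean) / (2 * std). The dummy array replaces real features, and the mean/std are the values used elsewhere in this thread:

```python
import numpy as np


def pad_or_crop(fbank, target_length):
    """Zero-pad or truncate a (frames, mel_bins) array along the time axis."""
    n_frames = fbank.shape[0]
    if n_frames < target_length:
        pad = np.zeros((target_length - n_frames, fbank.shape[1]), dtype=fbank.dtype)
        return np.concatenate([fbank, pad], axis=0)
    return fbank[:target_length]


def normalize(fbank, mean, std):
    # Assumed normalization: (x - mean) / (2 * std); check it against the linked dataloader line.
    return (fbank - mean) / (std * 2)


# Dummy features standing in for real log-mel filterbanks: 300 frames x 128 mel bins.
dummy = np.random.randn(300, 128).astype(np.float32)
x = normalize(pad_or_crop(dummy, 1024), mean=-4.2677393, std=4.5689974)
print(x.shape)  # (1024, 128)
```

The model then expects a batch dimension, i.e. shape (batch, 1024, 128), so a single clip would be x[None] converted to a tensor.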
Thanks Yuan for your suggestions.
To be clear, where do I add the DataParallel wrapper? In your example you apply it after creating the ASTModel object, but it is that initial call that fails to load, so I suspect I would need to modify ast_models.py. Or are you suggesting that I instead convert the serialised model to a DataParallel one?
You should do something like this: https://github.com/YuanGongND/ast/blob/7b2fe7084b622e540643b0d7d7ab736b5eb7683b/egs/audioset/inference.py#L82-L89
i.e., call audio_model.load_state_dict(checkpoint) after converting the model to a DataParallel object.
I don't recommend changing ast_models.py.
Something like the following should work:
import torch
from models import ASTModel  # assumes you run this from the repo's src/ directory

input_tdim = 1024
ast_mdl = ASTModel(label_dim=2,
                   fshape=16,
                   tshape=16,
                   fstride=10,
                   tstride=10,
                   input_fdim=128,
                   input_tdim=input_tdim,
                   model_size='tiny',
                   pretrain_stage=False,
                   load_pretrained_mdl_path=MODEL)
# convert it to a torch.nn.DataParallel object
ast_mdl = torch.nn.DataParallel(ast_mdl)
# load the checkpoint into the wrapped model
checkpoint = torch.load(checkpoint_path, map_location='cuda')
ast_mdl.load_state_dict(checkpoint)
ast_mdl.eval()
# then do inference as normal; fbank is your (batch, time_frames, mel_bins) spectrogram tensor
output = ast_mdl(fbank, task='ft_avgtok')
Sorry, I might not have been clear here.
The ast_mdl = ASTModel(...) line is the one that fails, so I cannot convert the model afterwards, since that line never completes.
In that call, MODEL is the model file saved to experiment/models/best_audio_model.pth during my earlier fine-tuning.
Or are you suggesting that I load the pre-trained model provided by this repo (SSAST-Tiny-Patch-400) and then load my checkpoint on top of it? If so, is the checkpoint best_audio_model.pth or best_optim_state.pth?
Does that make sense?
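One way to answer the last question empirically: assuming best_optim_state.pth holds an optimizer's state_dict (as its name suggests), it will not contain model weights at all. A PyTorch optimizer state_dict has the top-level keys 'state' and 'param_groups', which the hypothetical helper below checks for. The toy dicts stand in for what torch.load would return:

```python
def is_optimizer_state(obj):
    """True if obj looks like the output of torch.optim.Optimizer.state_dict()."""
    return isinstance(obj, dict) and {"state", "param_groups"} <= set(obj)


# Toy stand-ins for what torch.load might return for each file.
optim_like = {"state": {}, "param_groups": [{"lr": 5e-5, "params": [0, 1]}]}
model_like = {"module.v.cls_token": None, "module.mlp_head.1.weight": None}
print(is_optimizer_state(optim_like), is_optimizer_state(model_like))  # True False
```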
That's weird. If you used my recipe to fine-tune the model, the saved model should already be a DataParallel object.
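A quick way to test that expectation on a given file: nn.DataParallel saves its wrapped module's parameters under keys prefixed with "module.". A minimal sketch, with a hypothetical helper name and toy keys in place of a real checkpoint:

```python
def saved_from_dataparallel(state_dict, prefix="module."):
    """True if every key carries the prefix that nn.DataParallel adds."""
    return bool(state_dict) and all(key.startswith(prefix) for key in state_dict)


print(saved_from_dataparallel({"module.v.cls_token": None}))  # True
print(saved_from_dataparallel({"v.cls_token": None}))         # False
```

On a real file you would call it on torch.load('best_audio_model.pth', map_location='cpu'); if it returns True, wrap the model in DataParallel before calling load_state_dict.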
Yes... I most certainly used your code.
I essentially used the AudioSet fine-tuning script. Here is the full .sh file with my modifications (not many):
set -x
export TORCH_HOME=../../pretrained_models
mkdir -p ./exp
if [ -e SSAST-Tiny-Patch-400.pth ]
then
echo "pretrained model already downloaded."
else
wget https://www.dropbox.com/s/ewrzpco95n9jdz6/SSAST-Tiny-Patch-400.pth?dl=1 -O SSAST-Tiny-Patch-400.pth
fi
pretrain_exp=
pretrain_model=SSAST-Tiny-Patch-400
pretrain_path=./${pretrain_exp}/${pretrain_model}.pth
dataset=testdata
dataset_mean=-4.2677393
dataset_std=4.5689974
target_length=1024
noise=False
task=ft_avgtok
model_size=base
head_lr=1
warmup=True
bal=none
lr=5e-5
epoch=25
tr_data=/content/drive/MyDrive/ssast_train1.json
te_data=/content/drive/MyDrive/ssast_val1.json
freqm=48
timem=192
mixup=0.5
fstride=10
tstride=10
fshape=16
tshape=16
batch_size=2
exp_dir=./exp/test01-${dataset}-f${fstride}-${fshape}-t${tstride}-${tshape}-b${batch_size}-lr${lr}-${task}-${model_size}-${pretrain_exp}-${pretrain_model}-${head_lr}x-noise${noise}-3
CUDA_CACHE_DISABLE=1 python -W ignore ../../run.py --dataset ${dataset} \
--data-train ${tr_data} --data-val ${te_data} --exp-dir $exp_dir \
--label-csv ./data/class_labels_indices.csv --n_class 2 \
--lr $lr --n-epochs ${epoch} --batch-size $batch_size --save_model False \
--freqm $freqm --timem $timem --mixup ${mixup} --bal ${bal} \
--tstride $tstride --fstride $fstride --fshape ${fshape} --tshape ${tshape} --warmup False --task ${task} \
--model_size ${model_size} --adaptschedule False \
--pretrained_mdl_path ${pretrain_path} \
--dataset_mean ${dataset_mean} --dataset_std ${dataset_std} --target_length ${target_length} \
--num_mel_bins 128 --head_lr ${head_lr} --noise ${noise} \
--lrscheduler_start 10 --lrscheduler_step 5 --lrscheduler_decay 0.5 --wa True --wa_start 6 --wa_end 25 \
--loss BCE --metrics mAP
I see. It might be caused by a bug in the code. I didn't consider your use case.
If the model is not too large, can you send the .pth file to me at yuangong@mit.edu?
I can take a look, but not immediately; I will need to find some spare time.
FWIW, I got some kind of inference pipeline running, although the results do not match the output originally generated by your fine-tuning recipe, so I'm guessing there are major bugs in what I got working. I thought it might be relevant anyway. All of this was done after I successfully ran the fine-tuning scripts on a new dataset for binary classification:
First, use the same JSON-style approach to build a dataloader (using the AudioDataset class from your dataloader):
import torch
import dataloader  # the repo's src/dataloader.py

audio_conf = {
    'num_mel_bins': 128,
    'target_length': 1024,
    'freqm': 0,    # no SpecAugment masking at inference (run.py's validation config also uses 0)
    'timem': 0,
    'mixup': 0,    # no mixup at inference
    'dataset': 'testdata',
    'mode': 'evaluation',
    'mean': -4.2677393,
    'std': 4.5689974,
    'noise': False
}
train_loader = torch.utils.data.DataLoader(
    dataloader.AudioDataset(val_json,
                            label_csv=labels_csv,
                            audio_conf=audio_conf),
    batch_size=1,
    shuffle=False)
Next, load the original pre-trained model that I fine-tuned from:
input_tdim = 1024
ast_mdl = ASTModel(label_dim=2,
                   fshape=16,
                   tshape=16,
                   fstride=10,
                   tstride=10,
                   input_fdim=128,
                   input_tdim=input_tdim,
                   model_size='tiny',
                   pretrain_stage=False,
                   load_pretrained_mdl_path='SSAST-Tiny-Patch-400.pth')
Then, load the saved state checkpoint into it (I have no idea whether this works as expected, but it is the only way I got anything to run without errors):
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sd = torch.load('best_optim_state.pth', map_location=device)
if not isinstance(ast_mdl, torch.nn.DataParallel):
    ast_mdl = torch.nn.DataParallel(ast_mdl)
ast_mdl = ast_mdl.to(device)  # DataParallel expects the parameters on the first GPU
ast_mdl.load_state_dict(sd, strict=False)
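A caveat worth checking here: load_state_dict(..., strict=False) does not raise when keys are missing or unexpected; it returns a named tuple with missing_keys and unexpected_keys. If every model key shows up as missing, nothing was actually loaded and the model keeps its pre-trained weights, which could explain outputs that don't match. The sketch below reproduces that comparison on plain key lists (toy keys, hypothetical helper name):

```python
def key_overlap(model_keys, checkpoint_keys):
    """Mimic the missing/unexpected key report of load_state_dict(strict=False)."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Toy keys: a model's parameters versus an optimizer state_dict's top-level keys.
model_keys = ["module.v.cls_token", "module.mlp_head.1.weight"]
missing, unexpected = key_overlap(model_keys, ["state", "param_groups"])
print(len(missing), unexpected)  # 2 ['param_groups', 'state']
```

With the real objects, result = ast_mdl.load_state_dict(sd, strict=False) followed by print(len(result.missing_keys)) gives the same information.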
Then, copying bits and pieces from the supplied traintest.py:
import numpy as np
import torch

ast_mdl.eval()
with torch.no_grad():
    for i, (audio_input, labels) in enumerate(train_loader):
        prediction = torch.sigmoid(ast_mdl(audio_input, task='ft_avgtok')).to('cpu').detach()
        print(prediction.shape)
        print(np.array(prediction))
For single samples, the prediction probabilities neither sum to 1 nor match the values I expected from the supplied fine-tuning recipe.
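Note that the recipe trains with --loss BCE and this code applies torch.sigmoid, so each class gets an independent probability and the values are not expected to sum to 1; softmax is what normalizes across classes. A small illustration with made-up logits for a 2-class head:

```python
import math


def sigmoid(logits):
    """Independent per-class probabilities (what BCE training pairs with)."""
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]


def softmax(logits):
    """Probabilities normalized across classes (numerically stable form)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0]  # made-up logits for a 2-class head
print(round(sum(sigmoid(logits)), 3))  # 1.612
print(round(sum(softmax(logits)), 3))  # 1.0
```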
The problem is that I used a trick: I encode the pretraining hyperparameters in the model and use the existence of those hyperparameters to check whether the model is a DataParallel object. The SSL pretraining code saves the hyperparameters but the fine-tuning code does not, so when you load the fine-tuned model for another round of testing, the code cannot find the hyperparameters and concludes that the model is not a DataParallel object.
As a temporary workaround, you can change these two lines of code:
I will find time to fix it.
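The two lines referred to above are not reproduced in this thread. Purely as an illustration of the general idea (not the repo's actual code), a loader could fall back to known input dimensions instead of treating missing hyperparameters as "not DataParallel". The key names module.p_input_fdim and module.p_input_tdim are assumptions (print your checkpoint's keys to confirm), and the defaults of 128 mel bins and 1024 frames match the settings used in this thread:

```python
def pretrain_input_dims(state_dict, default_fdim=128, default_tdim=1024,
                        fdim_key="module.p_input_fdim",
                        tdim_key="module.p_input_tdim"):
    """Return (fdim, tdim) stored in a checkpoint, or the defaults if absent.

    The key names are assumptions; int() accepts plain ints and
    single-element tensors alike.
    """
    if fdim_key in state_dict and tdim_key in state_dict:
        return int(state_dict[fdim_key]), int(state_dict[tdim_key])
    return default_fdim, default_tdim


# Toy checkpoints: one with the stored entries, one (fine-tuned style) without them.
print(pretrain_input_dims({"module.p_input_fdim": 128, "module.p_input_tdim": 512}))  # (128, 512)
print(pretrain_input_dims({"module.v.cls_token": None}))  # (128, 1024) via the defaults
```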
Even after changing these two lines, I still cannot load the fine-tuned model.
Hi, I am running into exactly the same error and am having trouble loading a fine-tuned model. I am wondering whether @beyondbeneath ever found a solution to this.
Hi, I have the same issue. Any updates or recommendations on how to fix it would be highly appreciated :)
Hello!
Firstly, thanks for this great work!
I managed to modify the AudioSet fine-tuning script and fine-tuned a model on a new binary audio classification task. I started with the "Tiny" Patch model and used a batch size of 2. The resulting predictions on the evaluation set looked very promising!
I'm now trying to write an inference script that uses the saved model to make predictions, and I'm running into some trouble. Which method do I actually need to call for pure inference? The documentation seems to describe only pre-training and fine-tuning, not inference.
More pressingly, I can't actually get the model to load. I am trying to load best_audio_model.pth as follows:
however, this results in the errors:
Is there anything obvious I'm missing or doing wrong? I'd appreciate any guidance on how to load this model and how to perform inference on a new .wav file. Thanks!
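Once you have probabilities, mapping them back to class names is a small final step. A sketch, assuming the label CSV follows the AudioSet-style index,mid,display_name layout (the inline CSV below is a hypothetical stand-in for ./data/class_labels_indices.csv):

```python
import csv
import io


def load_label_names(csv_file):
    """Map class index -> display name from an index,mid,display_name CSV."""
    return {int(row["index"]): row["display_name"] for row in csv.DictReader(csv_file)}


def top_labels(probs, names, k=1):
    """Return the k (name, probability) pairs with the highest probability."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(names[i], probs[i]) for i in order[:k]]


# Hypothetical two-class label file standing in for ./data/class_labels_indices.csv.
example_csv = io.StringIO("index,mid,display_name\n0,/m/neg,negative\n1,/m/pos,positive\n")
names = load_label_names(example_csv)
print(top_labels([0.2, 0.9], names))  # [('positive', 0.9)]
```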