kkoutini / PaSST

Efficient Training of Audio Transformers with Patchout

test my own model #36

Open fuguanyu opened 1 year ago

fuguanyu commented 1 year ago

Hello, I have a few questions. I see that the pre-trained models are all .pt files, while the model I trained with the default parameters is saved as a .ckpt file. That aside, when I use "passt_s_swa_p16_128_ap476" as the pre-trained model and try to validate my fine-tuned model, some problems arise.

First, my checkpoint contains a second set of parameters prefixed with net_swa., which is presumably related to the use of SWA in the code. But the description of the pre-trained model says SWA was used there as well, so why are there no net_swa. parameters when I print the pre-trained model? As a result, loading my own model fails with "Unexpected key(s) in state_dict". I think this part of the code is responsible. How can I solve this problem?

[screenshots omitted]

In addition, how should I write a script to validate my own model on a single piece of audio?

kkoutini commented 1 year ago

Hi, the net_swa model is created here to hold the moving average (SWA) of the model weights during training. The released .pt checkpoints contain only the final network weights, which is presumably why they have no net_swa. keys. If you want to load the model from your own checkpoint, you can do something like this:

import torch

# p is the path to your .ckpt file saved by the trainer
ckpt = torch.load(p)
# plain model weights: strip the "net." prefix from the keys
net_statedict = {k[len("net."):]: v.cpu() for k, v in ckpt['state_dict'].items() if k.startswith("net.")}
# SWA (moving-average) weights: strip the "net_swa." prefix from the keys
net_swa_statedict = {k[len("net_swa."):]: v.cpu() for k, v in ckpt['state_dict'].items() if k.startswith("net_swa.")}

then:

# modul is the training (Lightning) module that wraps the network as modul.net
modul.net.load_state_dict(net_statedict)
# or, to use the SWA-averaged weights instead:
modul.net.load_state_dict(net_swa_statedict)
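As a quick sanity check (a sketch; strict=False is plain PyTorch behavior, not specific to this repo), you can load non-strictly first and inspect which keys did not match:

# Optional sanity check: load non-strictly and report any mismatched keys
missing, unexpected = modul.net.load_state_dict(net_statedict, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)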

For validation, take a look at the validation step here.
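For a single audio clip, a minimal inference sketch could look like the following. It assumes the hear21passt package from this repo's README (pip install hear21passt), a mono waveform resampled to 32 kHz, and that n_classes matches your fine-tuned head; treat it as a starting point rather than the repo's official script:

import torch
from hear21passt.base import get_basic_model, get_model_passt

# build the mel-frontend + PaSST wrapper (API as shown in the README)
model = get_basic_model(mode="logits")
# n_classes=527 is AudioSet; set it to the number of classes you fine-tuned on
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=527)
model.net.load_state_dict(net_swa_statedict)  # the weights extracted above
model.eval()

# waveform: tensor of shape (batch, samples) at 32 kHz; load your own file
# with e.g. torchaudio or librosa. The placeholder below is 10 s of silence.
waveform = torch.zeros(1, 32000 * 10)
with torch.no_grad():
    logits = model(waveform)   # shape: (1, n_classes)
print(logits.argmax(dim=1))    # predicted class index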