fuguanyu opened this issue 1 year ago
Hi,
The net_swa model is created here to hold the moving average of the model weights during training (stochastic weight averaging).
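The moving-average idea can be sketched like this. Everything below is illustrative only: the names and the running-mean update rule are assumptions for explanation, not the repository's exact implementation.

```python
import copy

import torch
import torch.nn as nn

def swa_update(net, net_swa, n_averaged):
    """Average the current weights of `net` into `net_swa` in place.

    Running mean over checkpoints: swa <- (n * swa + w) / (n + 1).
    """
    with torch.no_grad():
        for p, p_swa in zip(net.parameters(), net_swa.parameters()):
            p_swa.mul_(n_averaged / (n_averaged + 1.0)).add_(p, alpha=1.0 / (n_averaged + 1.0))
    return n_averaged + 1

# Toy usage: start the SWA copy as a clone of the training model,
# then fold in the current weights after each training step/epoch.
net = nn.Linear(2, 2, bias=False)
net_swa = copy.deepcopy(net)
n = 0
n = swa_update(net, net_swa, n)
```

With n = 0 the first call simply copies the current weights; later calls keep an equal-weight average of all snapshots seen so far.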
If you want to load the model from the checkpoint, you can do something like this:
import torch

ckpt = torch.load(p)  # p is the path to your .ckpt file
# strip the "net." / "net_swa." prefixes that the lightning module adds to the keys
net_statedict = {k[len("net."):]: v.cpu() for k, v in ckpt['state_dict'].items() if k.startswith("net.")}
net_swa_statedict = {k[len("net_swa."):]: v.cpu() for k, v in ckpt['state_dict'].items() if k.startswith("net_swa.")}
then:
modul.net.load_state_dict(net_statedict)
# or, to use the SWA weights instead:
modul.net.load_state_dict(net_swa_statedict)
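The "Unexpected key(s) in state_dict" error comes from handing load_state_dict the full checkpoint dictionary, which contains both copies of the weights. A small self-contained sketch of the prefix-stripping idea, using a hypothetical toy network in place of the real model:

```python
import torch
import torch.nn as nn

def strip_prefix(state_dict, prefix):
    """Return only the entries whose key starts with `prefix`, with the prefix removed."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

# Toy stand-in for the real network (hypothetical, for illustration only).
net = nn.Linear(4, 2)

# Mimic a lightning checkpoint that stores both a "net." and a "net_swa." copy.
checkpoint_state = {}
checkpoint_state.update({"net." + k: v for k, v in net.state_dict().items()})
checkpoint_state.update({"net_swa." + k: v.clone() for k, v in net.state_dict().items()})

# Loading checkpoint_state directly would raise "Unexpected key(s) in state_dict".
# Stripping one prefix selects exactly one copy of the weights.
# (Note "net_swa.weight" does not start with "net.", so the filters don't overlap.)
net.load_state_dict(strip_prefix(checkpoint_state, "net."))      # plain weights
net.load_state_dict(strip_prefix(checkpoint_state, "net_swa."))  # SWA weights
```

Both load_state_dict calls then succeed with no unexpected or missing keys.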
For validation, take a look at the validation step here.
Hello, I would like to ask a few questions.

I see that the pre-trained models are all .pt files, while the model I trained without changing the default parameters is saved as a .ckpt file. That aside, when I use "passt_s_swa_p16_128_ap476" as the pre-trained model and then try to evaluate my fine-tuned model, some problems arise.

First, my checkpoint saves a second set of parameters prefixed with net_swa., which is probably related to the use of SWA in the code. But the pre-trained model described in the introduction also uses SWA, so why are there no net_swa. parameters when printing the pre-trained model? As a result, when I load my own model I get an "Unexpected key(s) in state_dict" error. I think it may be caused by this part of the code. How can I solve this problem?

In addition, I would like to ask: how should I write a script to evaluate my own model on a single piece of audio?
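For the single-audio question, a minimal sketch could look like the following. Everything here is an assumption for illustration: `model` stands in for your fine-tuned network loaded as above, `predict_single` is a made-up helper name, and the 32 kHz target rate reflects what the AudioSet pre-trained PaSST models are trained on; a real script would read the file and resample with torchaudio rather than the crude interpolation used here.

```python
import torch
import torch.nn.functional as F

def predict_single(model, waveform, sr, target_sr=32000):
    """Score one clip. `waveform` is a (channels, samples) float tensor."""
    if sr != target_sr:
        # crude resampling by linear interpolation; prefer
        # torchaudio.functional.resample for real audio
        n = int(waveform.shape[-1] * target_sr / sr)
        waveform = F.interpolate(waveform.unsqueeze(0), size=n,
                                 mode="linear", align_corners=False).squeeze(0)
    waveform = waveform.mean(dim=0, keepdim=True)  # downmix to mono: (1, samples)
    model.eval()
    with torch.no_grad():
        logits = model(waveform)  # expected shape: (1, num_classes)
    return logits.argmax(dim=-1).item()
```

You would then call it with the decoded waveform and its sample rate, e.g. `predict_single(modul.net, waveform, sr)`, and map the returned class index back to a label name.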