FIODDX opened this issue 1 year ago.
According to the log, the state_dict you are loading appears to be for the VGG-F model, while the test code builds the VGG model. Please check whether you specified the model name as "-model online_spiking_vgg11f_ws" during training instead of the default "online_spiking_vgg11_ws". If so, the test code must specify it as well.
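One way to check which variant a checkpoint came from is to look at its key names. The sketch below is a hypothetical heuristic based only on the error log in this issue: the "conv1." and "fb_conv." prefixes appear in the saved checkpoint but are not expected by OnlineSpikingVGG. The helper name and prefix list are assumptions, not part of the repository. In practice the keys would come from the state dict stored in the .pth file (for example after torch.load(path, map_location="cpu")); the exact layout of that checkpoint dict is not shown here.

```python
from typing import Iterable

# Hypothetical: key prefixes that, in this issue's error log, appear only in
# the checkpoint saved from the VGG-F variant.
VGGF_ONLY_PREFIXES = ("conv1.", "fb_conv.")


def guess_model_name(state_dict_keys: Iterable[str]) -> str:
    """Heuristically guess which -model value produced a checkpoint.

    Returns "online_spiking_vgg11f_ws" if any key starts with a prefix seen
    only in the VGG-F checkpoint, otherwise the default "online_spiking_vgg11_ws".
    """
    for key in state_dict_keys:
        # str.startswith accepts a tuple and matches any of its prefixes.
        if key.startswith(VGGF_ONLY_PREFIXES):
            return "online_spiking_vgg11f_ws"
    return "online_spiking_vgg11_ws"


# Keys copied from the error message in this issue.
saved_keys = ["conv1.weight", "fb_conv.op.weight", "features.2.op.weight"]
expected_keys = ["features.0.weight", "features.3.op.weight"]
print(guess_model_name(saved_keys))
print(guess_model_name(expected_keys))
```

This only distinguishes the two model names mentioned in this thread; other architectures would need their own key signatures.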
I trained an OTTT network on CIFAR-10 with the following code:
After all 300 epochs, I got checkpoint_latest.pth and checkpoint_max.pth in my log directory. But when I ran the following code to test the model:
I got the error below:
RuntimeError: Error(s) in loading state_dict for OnlineSpikingVGG: Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.0.gain", "features.3.op.weight", "features.3.op.bias", "features.3.op.gain", "features.7.op.weight", "features.7.op.bias", "features.7.op.gain", "features.10.op.weight", "features.10.op.bias", "features.10.op.gain", "features.14.op.weight", "features.14.op.bias", "features.14.op.gain", "features.17.op.weight", "features.17.op.bias", "features.17.op.gain", "features.21.op.weight", "features.21.op.bias", "features.21.op.gain", "features.24.op.weight", "features.24.op.bias", "features.24.op.gain". Unexpected key(s) in state_dict: "conv1.weight", "conv1.bias", "conv1.gain", "fb_conv.op.weight", "fb_conv.op.bias", "features.2.op.weight", "features.2.op.bias", "features.2.op.gain", "features.6.op.weight", "features.6.op.bias", "features.6.op.gain", "features.9.op.weight", "features.9.op.bias", "features.9.op.gain", "features.13.op.weight", "features.13.op.bias", "features.13.op.gain", "features.16.op.weight", "features.16.op.bias", "features.16.op.gain", "features.20.op.weight", "features.20.op.bias", "features.20.op.gain", "features.23.op.weight", "features.23.op.bias", "features.23.op.gain".
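The "Missing key(s)" and "Unexpected key(s)" lists come from comparing the key names the model defines with the key names in the checkpoint. The simplified sketch below, using a hypothetical helper name, mirrors that comparison with plain sets so you can diff the two key lists before calling load_state_dict. With PyTorch, model_keys would be model.state_dict().keys() and checkpoint_keys the keys of the loaded state dict.

```python
from typing import Iterable, List, Tuple


def diff_state_dict_keys(
    model_keys: Iterable[str], checkpoint_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Simplified mirror of the key check done by strict state_dict loading.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys the checkpoint contains but the model does not define
    """
    model_set = set(model_keys)
    ckpt_set = set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)
    unexpected = sorted(ckpt_set - model_set)
    return missing, unexpected


# A small subset of the keys from the error above.
model_keys = ["features.0.weight", "features.3.op.weight"]
checkpoint_keys = ["conv1.weight", "fb_conv.op.weight", "features.2.op.weight"]
missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

When both lists are non-empty and disjoint, as in this error, it usually means the model was built from a different architecture than the one that was saved, not that the file is corrupted.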
Why don't the key names in the .pth file match my model? It seems that the training and test code both use
net = spiking_vgg.__dict__[args.model](single_step_neuron=neuron.OnlineLIFNode, tau=args.tau, surrogate_function=surrogate.Sigmoid(), track_rate=True, c_in=3, num_classes=num_classes, neuron_dropout=args.drop_rate, grad_with_rate=True, fc_hw=1, v_reset=None)
to generate the network.
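Because the network is selected by args.model, the test script builds the default model unless -model is passed explicitly. The sketch below assumes the test script declares -model the same way as the training script, with the default "online_spiking_vgg11_ws"; that declaration is an assumption for illustration, not copied from the repository.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Assumption: -model is declared exactly as in the training script,
    # with the same default value.
    parser = argparse.ArgumentParser()
    parser.add_argument("-model", type=str, default="online_spiking_vgg11_ws")
    return parser


parser = build_parser()

# Without -model, the default (plain VGG) name is used, so
# spiking_vgg.__dict__[args.model] builds a model whose keys differ
# from a VGG-F checkpoint.
print(parser.parse_args([]).model)

# Passing the same value used for training selects the matching constructor.
print(parser.parse_args(["-model", "online_spiking_vgg11f_ws"]).model)
```

Using the same -model value for training and testing keeps the constructed network, and therefore its state_dict keys, identical between the two scripts.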