pkuxmq / OTTT-SNN

[NeurIPS 2022] Online Training Through Time for Spiking Neural Networks

Can't reload the CIFAR-10 model #1

Open FIODDX opened 1 year ago

FIODDX commented 1 year ago

I trained the OTTT network on CIFAR-10 with the following command:

python train_cifar.py -data_dir path_to_data_dir -dataset cifar10 -out_dir log_checkpoint_name -gpu-id 0

After all 300 epochs finished, I had checkpoint_latest.pth and checkpoint_max.pth in my log directory. But when I ran the following command to test the model:

python get_rate_cifar.py -data_dir path_to_data_dir -dataset cifar10 -gpu-id 0 -resume path_to_checkpoint

I got the error below:

RuntimeError: Error(s) in loading state_dict for OnlineSpikingVGG: Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.0.gain", "features.3.op.weight", "features.3.op.bias", "features.3.op.gain", "features.7.op.weight", "features.7.op.bias", "features.7.op.gain", "features.10.op.weight", "features.10.op.bias", "features.10.op.gain", "features.14.op.weight", "features.14.op.bias", "features.14.op.gain", "features.17.op.weight", "features.17.op.bias", "features.17.op.gain", "features.21.op.weight", "features.21.op.bias", "features.21.op.gain", "features.24.op.weight", "features.24.op.bias", "features.24.op.gain". Unexpected key(s) in state_dict: "conv1.weight", "conv1.bias", "conv1.gain", "fb_conv.op.weight", "fb_conv.op.bias", "features.2.op.weight", "features.2.op.bias", "features.2.op.gain", "features.6.op.weight", "features.6.op.bias", "features.6.op.gain", "features.9.op.weight", "features.9.op.bias", "features.9.op.gain", "features.13.op.weight", "features.13.op.bias", "features.13.op.gain", "features.16.op.weight", "features.16.op.bias", "features.16.op.gain", "features.20.op.weight", "features.20.op.bias", "features.20.op.gain", "features.23.op.weight", "features.23.op.bias", "features.23.op.gain".

Why don't the key names in the .pth file match my model? Both scripts seem to use net = spiking_vgg.__dict__[args.model](single_step_neuron=neuron.OnlineLIFNode, tau=args.tau, surrogate_function=surrogate.Sigmoid(), track_rate=True, c_in=3, num_classes=num_classes, neuron_dropout=args.drop_rate, grad_with_rate=True, fc_hw=1, v_reset=None) to generate the network.
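The error comes from the strict key check in PyTorch's load_state_dict: every parameter name the model expects must be present in the checkpoint, and every name in the checkpoint must exist in the model. A minimal sketch of that comparison in plain Python, using a small excerpt of the keys from the log above (the function name diff_state_dict_keys is hypothetical, not part of OTTT-SNN or PyTorch). In PyTorch itself, calling net.load_state_dict(state_dict, strict=False) returns the same two lists as missing_keys and unexpected_keys instead of raising.

```python
# Hypothetical sketch of the strict check behind "Missing key(s)" and
# "Unexpected key(s)": compare the keys the model expects with the keys
# stored in the checkpoint.


def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, preserving the original order."""
    ckpt = set(checkpoint_keys)
    expected = set(model_keys)
    missing = [k for k in model_keys if k not in ckpt]            # model wants, checkpoint lacks
    unexpected = [k for k in checkpoint_keys if k not in expected]  # checkpoint has, model lacks
    return missing, unexpected


# Keys the VGG model built by get_rate_cifar.py expects (excerpt of the log).
model_keys = ["features.0.weight", "features.0.bias", "features.0.gain",
              "features.3.op.weight", "features.3.op.bias", "features.3.op.gain"]
# Keys found in the trained checkpoint (excerpt of the log).
checkpoint_keys = ["conv1.weight", "conv1.bias", "conv1.gain",
                   "fb_conv.op.weight", "fb_conv.op.bias",
                   "features.2.op.weight", "features.2.op.bias", "features.2.op.gain"]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("Missing:", missing)
print("Unexpected:", unexpected)
```

None of the checkpoint keys appear in the model at all, so this is not a simple renaming (such as a "module." prefix); the two networks have different layer layouts.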

pkuxmq commented 1 year ago

According to the error log, the state_dict you are loading appears to belong to the VGG-F model, while the test code builds the VGG model. Please check whether you specified the model name as "-model online_spiking_vgg11f_ws" during training rather than using the default "online_spiking_vgg11_ws". If so, the test code should specify the same model name.
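Following this suggestion, a hypothetical helper can guess the variant from the checkpoint keys and build a matching test command. The marker keys ("conv1.weight", "fb_conv.op.weight") are taken from the log in this issue and are an assumption, not something documented by OTTT-SNN; the flags are the ones used in the commands above plus -model. The exact layout of the saved checkpoint file is also not confirmed here, so the sketch works on a plain list of key names.

```python
# Hypothetical helpers (not part of OTTT-SNN): guess the model variant from
# checkpoint keys and build a get_rate_cifar.py command that matches it.

# Assumption: these keys appear only in the VGG-F variant, as in this issue's log.
VGG_F_MARKERS = ("conv1.weight", "fb_conv.op.weight")


def guess_model_name(checkpoint_keys):
    """Return the -model value that likely matches the checkpoint, or None."""
    keys = set(checkpoint_keys)
    if any(marker in keys for marker in VGG_F_MARKERS):
        return "online_spiking_vgg11f_ws"  # VGG-F variant
    if "features.0.weight" in keys:
        return "online_spiking_vgg11_ws"  # default VGG variant
    return None  # unknown layout: check the training command instead


def build_test_command(checkpoint_keys, data_dir, checkpoint_path, gpu_id="0"):
    """Build the test command with the same -model value used for training."""
    model = guess_model_name(checkpoint_keys)
    if model is None:
        raise ValueError("cannot infer the model variant from checkpoint keys")
    return ["python", "get_rate_cifar.py",
            "-data_dir", data_dir, "-dataset", "cifar10",
            "-gpu-id", gpu_id, "-resume", checkpoint_path,
            "-model", model]


# Usage with an excerpt of the checkpoint keys from the error message.
cmd = build_test_command(["conv1.weight", "fb_conv.op.weight"],
                         "path_to_data_dir", "path_to_checkpoint")
print(" ".join(cmd))
```

In practice the simplest fix is to pass the same -model value to both train_cifar.py and get_rate_cifar.py; inferring it from the keys is only a fallback when the original training command is no longer known.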