fangwei123456 / Spike-Element-Wise-ResNet

Deep Residual Learning in Spiking Neural Networks
Mozilla Public License 2.0
140 stars, 21 forks

How to use pretrained weights? #6

Closed: mountains-high closed this issue 2 years ago

mountains-high commented 2 years ago

Since the weights were saved via state_dict(), load them into a freshly built model:

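import torch
import spiking_resnet  # model definitions from this repository

net = spiking_resnet.spiking_resnet34()
# the checkpoint is a dict; the model weights are stored under the 'model' key
net.load_state_dict(torch.load("spiking_resnet_34_checkpoint_319.pth", map_location="cpu")['model'])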
To reproduce the environment:

  1. Use an Anaconda virtual environment, both to avoid permission and dependency problems and to keep this specific version of SpikingJelly isolated in that environment.

  2. git clone https://github.com/fangwei123456/spikingjelly.git

  3. cd spikingjelly

  4. git reset --hard 2958519df84ad77c316c6e6fbfac96fb2e5f59a3  # the SEW code was written against this older version of SpikingJelly

  5. python setup.py install

If these commands run without errors, you are all set. You need these steps in any case if you want to train from scratch or reproduce the results.
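A quick way to check, from Python, whether the installed SpikingJelly will hit the error below is to attempt the same imports it performs. This is a minimal sketch: the helper name check_import is hypothetical, and the module names are taken from the traceback in this post.

```python
import importlib


def check_import(module_name: str) -> str:
    """Try to import module_name and report 'ok' or the error that was raised."""
    try:
        importlib.import_module(module_name)
    except Exception as exc:  # ImportError, or anything raised while the module initialises
        return f"{type(exc).__name__}: {exc}"
    return "ok"


if __name__ == "__main__":
    # spikingjelly.cext.neuron runs `import _C_neuron`, the compiled CUDA extension
    for name in ("spikingjelly", "spikingjelly.cext.neuron", "_C_neuron"):
        print(f"{name} -> {check_import(name)}")
```

If the last two lines report ModuleNotFoundError, the CUDA extension was not built and the workaround below applies.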

If you get the following error message:

  from spikingjelly.cext import neuron as cext_neuron
  File "/home/Desktop/SEW_2/spikingjelly/spikingjelly/cext/neuron.py", line 5, in <module>
    import _C_neuron
ModuleNotFoundError: No module named '_C_neuron'

then do the following:

  1. Replace each cext_neuron.MultiStepIFNode (e.g. self.sn1) with a multi-step wrapper around the clock-driven IFNode:
# before: requires the compiled CUDA extension (_C_neuron)
#from spikingjelly.cext import neuron as cext_neuron
#self.sn1 = cext_neuron.MultiStepIFNode(detach_reset=True)
# after: pure-PyTorch IF neuron, wrapped to process the whole time dimension
from spikingjelly.clock_driven import neuron, layer, surrogate
self.sn1 = layer.MultiStepContainer(neuron.IFNode(detach_reset=True, surrogate_function=surrogate.ATan()))
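If you switch between machines with and without the compiled extension, one option is to choose the neuron when the model is built instead of editing the import by hand. This is a minimal sketch that assumes the old SpikingJelly commit above. make_multistep_if_node and cuda_neuron_extension_available are hypothetical helpers, not part of SpikingJelly or this repository, and the pure-PyTorch branch may still need the surrogate.py change described in this post.

```python
import importlib.util


def cuda_neuron_extension_available() -> bool:
    """True if the compiled module _C_neuron can be found on the import path."""
    return importlib.util.find_spec("_C_neuron") is not None


def make_multistep_if_node(detach_reset: bool = True):
    """Hypothetical helper: prefer the CUDA neuron, otherwise fall back to pure PyTorch."""
    if cuda_neuron_extension_available():
        from spikingjelly.cext import neuron as cext_neuron
        return cext_neuron.MultiStepIFNode(detach_reset=detach_reset)
    # fallback: single-step IFNode applied step by step over the leading time dimension
    from spikingjelly.clock_driven import layer, neuron, surrogate
    return layer.MultiStepContainer(
        neuron.IFNode(detach_reset=detach_reset, surrogate_function=surrogate.ATan())
    )
```

Inside a block's __init__, self.sn1 = make_multistep_if_node() would then replace the hard-coded line.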

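I think this is needed because setup.py built SpikingJelly without the CUDA extension: as the install log shows, CUDA_HOME was None, so the extension (and with it _C_neuron) was skipped.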

CUDA_HOME is None. Install Without CUDA Extension
running install
running bdist_egg
running egg_info
creating spikingjelly.egg-info
writing spikingjelly.egg-info/PKG-INFO
writing dependency_links to spikingjelly.egg-info/dependency_links.txt
......

.....
Using /home/anaconda3/envs/sew/lib/python3.9/site-packages/typing_extensions-4.1.1-py3.9.egg
Searching for six==1.16.0
Best match: six 1.16.0
Processing six-1.16.0-py3.9.egg
six 1.16.0 is already the active version in easy-install.pth

Using /home/anaconda3/envs/sew/lib/python3.9/site-packages/six-1.16.0-py3.9.egg
Finished processing dependencies for spikingjelly==0.0.0.0.4
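The "CUDA_HOME is None" line is the key hint in this log. Below is a rough standard-library sketch for checking, before reinstalling, whether a CUDA toolkit can be located. It only approximates the lookup that torch.utils.cpp_extension performs (the CUDA_HOME/CUDA_PATH environment variables, then the location of nvcc, then /usr/local/cuda), and find_cuda_home is a hypothetical name.

```python
import os
import shutil
from typing import Optional


def find_cuda_home() -> Optional[str]:
    """Approximate CUDA toolkit lookup; None suggests the extension build would be skipped."""
    # 1. explicit environment variables take priority
    home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
    if home:
        return home
    # 2. otherwise infer it from nvcc on PATH, e.g. /usr/local/cuda/bin/nvcc -> /usr/local/cuda
    nvcc = shutil.which("nvcc")
    if nvcc:
        return os.path.dirname(os.path.dirname(nvcc))
    # 3. finally fall back to the conventional default location, if it exists
    default = "/usr/local/cuda"
    return default if os.path.exists(default) else None


if __name__ == "__main__":
    print("CUDA home:", find_cuda_home())
```

If this prints None, setting CUDA_HOME before running python setup.py install may let the extension build. Whether it then compiles against your PyTorch version is a separate question.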
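  2. Make the following change in surrogate.py, storing alpha as a plain attribute instead of a registered buffer:
    # before: alpha is registered as a buffer, so it is saved in state_dict()
    #self.register_buffer('alpha', torch.tensor(alpha, dtype=torch.float))
    # after: a plain float attribute, which state_dict() does not include
    self.alpha = alpha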
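Here is my reading of why this change matters; the thread does not confirm it. A tensor registered with register_buffer becomes part of the module's state_dict(). The rebuilt model therefore presumably expects extra ...alpha entries that the checkpoint (trained with the CUDA neurons) does not contain, and a strict load_state_dict fails on those missing keys. The stdlib sketch below compares two key sets the way a strict load does. diff_state_dict_keys and the example key names are hypothetical.

```python
from typing import Iterable, List, Tuple


def diff_state_dict_keys(model_keys: Iterable[str],
                         checkpoint_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected): the two lists a strict load_state_dict complains about."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)      # expected by the model, absent from the file
    unexpected = sorted(ckpt_set - model_set)   # present in the file, unknown to the model
    return missing, unexpected


# Hypothetical keys: the rebuilt model registers alpha as a buffer, the checkpoint has no such entry.
model_keys = ["conv1.weight", "bn1.weight", "sn1.surrogate_function.alpha"]
checkpoint_keys = ["conv1.weight", "bn1.weight"]
print(diff_state_dict_keys(model_keys, checkpoint_keys))
# (['sn1.surrogate_function.alpha'], [])
```

With real objects you would pass net.state_dict().keys() and torch.load(...)['model'].keys(). PyTorch's net.load_state_dict(sd, strict=False) also returns these two lists as missing_keys and unexpected_keys.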

If you follow these steps, you will be able to use the pretrained weights. Thank you, Dr. Wei, for your help and time.

I'm posting this to save others' time, and Dr. Wei's, in the future :) Hope this helps.

P.S.: Suppose the weights had instead been saved as

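# torch.save returns None, so its result is not assigned
torch.save({
    'net': model_without_ddp,                   # the whole module, pickled
    'model': model_without_ddp.state_dict(),    # weights only
    'optimizer': optimizer.state_dict(),
    'lr_scheduler': lr_scheduler.state_dict(),
    'epoch': epoch,
    'args': args,
    'max_test_acc1': max_test_acc1,
}, check_point_max_path)

# later: load the pickled module directly instead of rebuilding it and loading the state_dict
net = torch.load("spiking_resnet_34_checkpoint_319.pth")['net']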

Would we still run into the issues above? Thank you.
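Some background on this question: torch.save pickles a whole nn.Module, and pickle stores an object's class only as a module path plus a class name. torch.load therefore has to import that module again, which here would presumably be spikingjelly.cext.neuron and, through it, _C_neuron. The stdlib sketch below reproduces that failure with a throwaway module. fake_cext_neuron and the class body are made up for illustration.

```python
import pickle
import sys
import types


def reload_after_module_removed(module_name: str = "fake_cext_neuron") -> str:
    """Pickle an instance, make its defining module unimportable, then try to unpickle it.

    Returns the exception name raised on load, or "ok" if loading succeeded.
    """
    mod = types.ModuleType(module_name)
    # define the class inside the throwaway module, so its __module__ is module_name
    exec(
        "class MultiStepIFNode:\n"
        "    def __init__(self):\n"
        "        self.v_threshold = 1.0\n",
        mod.__dict__,
    )
    sys.modules[module_name] = mod                  # importable while pickling
    blob = pickle.dumps(mod.MultiStepIFNode())      # stores 'fake_cext_neuron.MultiStepIFNode' + state
    del sys.modules[module_name]                    # simulate the module no longer being importable
    try:
        pickle.loads(blob)
    except ImportError as exc:                      # ModuleNotFoundError subclasses ImportError
        return type(exc).__name__
    return "ok"


print(reload_after_module_removed())  # ModuleNotFoundError
```

By the same mechanism, a checkpoint holding the whole module would presumably fail to load wherever spikingjelly.cext.neuron cannot be imported. A state_dict only needs a model definition that can be built.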

fangwei123456 commented 2 years ago

Thanks! I think it will be helpful for others!