synsense / sinabs

A PyTorch-based deep learning library for spiking neural networks, focused on fast training and supporting inference on neuromorphic hardware.
https://sinabs.readthedocs.io
GNU Affero General Public License v3.0

Question: Where does the weight transfer happen? #207

Closed cowolff closed 7 months ago

cowolff commented 8 months ago

Quick question regarding this tutorial: https://sinabs.readthedocs.io/en/v1.2.10/tutorials/weight_transfer_mnist.html I understand that the model conversion happens in this line:

sinabs_model = from_model(
    ann, input_shape=input_shape, add_spiking_output=True, synops=False, num_timesteps=num_timesteps
)

What I don't understand is where the weight conversion from the ANN to the SNN actually happens. As far as I can tell from the three files from_torch.py, conversion.py, and network.py, the model is first recreated in the same "shape" as the original model, just with spiking layers, and then handed over to the network class together with the old model. This is then returned to the user. Still, I don't see where the weights from the ANN are actually transferred to the SNN.

bauerfe commented 7 months ago

Hi @cowolff , thank you for bringing up this question.

Your understanding of how the conversion from ANN to SNN works is pretty good. Let me try to summarize what's happening:

First, the original model is copied, including all layers, weights, and other parameters. After that, the activation layers (e.g. ReLU) are replaced by spiking layers.

So the weights are copied to the new model 1:1 - their shape and values stay exactly the same.

This means that you might have to scale weights or firing thresholds to make sure that the spiking activity is in the desired range. You can read more about this here: https://sinabs.readthedocs.io/en/v1.2.1/tutorials/weight_scaling.html

cowolff commented 7 months ago

Thanks for your reply! But I am still wondering where this copying of the weights is happening in the code of the library? :)

bauerfe commented 7 months ago

In this line, from_model calls the replace_module function (defined here), which first makes a deepcopy of the network and then replaces the ReLU layers.