jeshraghian / snntorch

Deep and online learning with spiking neural networks in Python
https://snntorch.readthedocs.io/en/latest/
MIT License
1.37k stars 228 forks

Bug: Porting model to another device (CPU/GPU) #274

Open satabios opened 11 months ago

satabios commented 11 months ago

Description

When I load a pre-trained model and move it from GPU to CPU, or vice versa, certain variables remain on the original device and are not moved to the destination device.

Specifically, this affects the "mem" variable of snntorch._neurons.leaky.Leaky.

import torch

snn_pretrained_model_path = "snn_model.pth"
snn_model.load_state_dict(torch.load(snn_pretrained_model_path))
snn_model.to("cpu")  # or "cuda" (PyTorch has no "gpu" device string)
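A likely explanation, stated as an assumption rather than a confirmed diagnosis: nn.Module.to() only converts parameters and registered buffers, so a tensor stored as an ordinary attribute stays where it was. The sketch below uses a hypothetical ToyNeuron module and a dtype change instead of a device change, so it runs without a GPU; the same rule governs device moves.

```python
import torch
import torch.nn as nn


class ToyNeuron(nn.Module):
    """Hypothetical stand-in for a neuron layer that keeps hidden state."""

    def __init__(self):
        super().__init__()
        # Registered buffer: tracked by the module, converted by .to().
        self.register_buffer("mem_buffer", torch.zeros(3))
        # Plain attribute: invisible to .to(), so it is left untouched.
        self.mem_attr = torch.zeros(3)


neuron = ToyNeuron()
neuron.to(torch.float64)

print(neuron.mem_buffer.dtype)  # torch.float64
print(neuron.mem_attr.dtype)    # torch.float32
```

If "mem" in the installed snntorch version is a plain attribute rather than a buffer, a device move would leave it behind in exactly this way.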

What I Did

As a workaround, I iterate over the whole model, find every Leaky instance, and move its "mem" tensor to the destination device.

import torch
import snntorch

device = "cpu"  # or "cuda", depending on the destination device

# model.modules() recurses into nn.Sequential and any nested containers,
# so a single loop covers both the Sequential and the non-Sequential case.
for module in model.modules():
    if isinstance(module, snntorch._neurons.leaky.Leaky) and isinstance(module.mem, torch.Tensor):
        module.mem = module.mem.to(device)

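A more general variant of this workaround, sketched under the assumption that the stale state lives in plain tensor attributes. The helper name to_with_hidden_state is hypothetical; it calls model.to() and then moves every tensor found directly in each submodule's attribute dictionary, so it does not need to know about Leaky specifically.

```python
import torch
import torch.nn as nn


def to_with_hidden_state(model: nn.Module, device) -> nn.Module:
    """Hypothetical helper: move a model plus any plain tensor attributes.

    nn.Module.to() already handles parameters and registered buffers.
    This additionally moves tensors stored as ordinary attributes
    (for example a neuron's membrane state) on every submodule.
    """
    model.to(device)
    for module in model.modules():
        # vars(module) holds plain attributes; parameters, buffers and
        # child modules live in their own dicts and are skipped here.
        for name, value in list(vars(module).items()):
            if isinstance(value, torch.Tensor):
                setattr(module, name, value.to(device))
    return model
```

Usage would be to_with_hidden_state(snn_model, "cuda"). One design caveat: it also moves any tensor attribute that was deliberately kept on another device, so it is a blunt instrument compared with targeting Leaky layers only.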
ahenkes1 commented 10 months ago

@satabios, have you tried the steps in "Saving and loading models across devices in PyTorch"?

satabios commented 10 months ago

I followed "Saving and loading models across devices in PyTorch".

The issue isn't with saving or reloading; it is with moving the model from one device to another. As mentioned above, certain members of the model aren't transferred automatically.
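The distinction drawn here can be illustrated with PyTorch alone: torch.load(..., map_location=...) remaps the tensors inside a saved state_dict, but a state_dict only holds parameters and persistent buffers. Hidden state kept as a plain attribute is never saved, so neither saving nor loading can move it. The ToyNeuron class below is hypothetical, and an in-memory io.BytesIO buffer stands in for the .pth file.

```python
import io

import torch
import torch.nn as nn


class ToyNeuron(nn.Module):
    """Hypothetical layer: one Linear plus hidden state as a plain attribute."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 2)
        self.mem = torch.zeros(2)  # not a parameter or buffer


buffer = io.BytesIO()
torch.save(ToyNeuron().state_dict(), buffer)
buffer.seek(0)

# map_location remaps every saved tensor onto the requested device at load time.
state = torch.load(buffer, map_location="cpu")
print(sorted(state.keys()))  # ['fc.bias', 'fc.weight']  ("mem" was never saved)

restored = ToyNeuron()
restored.load_state_dict(state)
```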

ahenkes1 commented 10 months ago

So the issue also occurs with newly created models, without any saving or loading?

satabios commented 10 months ago

Either way, the bug persists: whether the model is already in memory and is moved to a different device (say from GPU to CPU or vice versa), or is loaded from a file, the move leaves these variables stuck on the original device.
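One possible library-side design, offered only as a sketch and not as snntorch's actual fix: register the membrane state with register_buffer(..., persistent=False). Buffers move with .to() like parameters, while persistent=False keeps them out of the state_dict, so saved checkpoints stay unchanged. BufferedNeuron is hypothetical and omits spiking, thresholding and reset logic; it only integrates its input with decay factor beta.

```python
import torch
import torch.nn as nn


class BufferedNeuron(nn.Module):
    """Hypothetical leaky integrator whose state is a non-persistent buffer."""

    def __init__(self, beta: float = 0.9):
        super().__init__()
        self.beta = beta
        # Moves with .to(), but is excluded from state_dict().
        self.register_buffer("mem", torch.zeros(0), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Lazily size the state to the input (same device and dtype as x).
        if self.mem.shape != x.shape:
            self.mem = torch.zeros_like(x)
        # Assigning to a registered buffer name keeps it registered as a buffer.
        self.mem = self.beta * self.mem + x
        return self.mem
```

Because mem stays a buffer after reassignment, a later neuron.to("cuda") or neuron.to("cpu") carries the current state along without any manual loop.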