jeshraghian / snntorch

Deep and online learning with spiking neural networks in Python
https://snntorch.readthedocs.io/en/latest/
MIT License

Error using load_state_dict for snntorch v0.8.* #314

Closed copparihollmann closed 2 months ago

copparihollmann commented 2 months ago

reproduce_error_load_snntorch.txt

Description

I wanted to load a model that I trained using snnTorch. I hadn't had any problems doing this before, so I investigated a bit further and realized that loaded_model.load_state_dict(torch.load("snn_model.pth")) throws an error for snntorch version 0.8.* but not for 0.7.*.

According to Prof. Eshraghian: "There have been some changes to refactor/simplify neuron models, and also to make it compatible with torch.compile()." and this could be the source of the problem.

What I Did

In case the file doesn't load, this is an out-of-the-box script you can run to reproduce the error:

import torch
import torch.nn as nn
import snntorch as snn

class SNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SNN, self).__init__()
        self.flatten = nn.Flatten()

        self.fc1 = nn.Linear(input_size, hidden_size)
        self.lif1 = snn.Leaky(beta=0.7)

        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.lif2 = snn.Leaky(beta=0.7)

        self.fc3 = nn.Linear(hidden_size, output_size)
        self.lif3 = snn.Leaky(beta=0.7)

    def forward(self, inpt):
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        spike3_rec = []
        mem3_rec = []

        for step in range(inpt.shape[0]):
            current_input = inpt[step]
            current_input = self.flatten(current_input)

            current1 = self.fc1(current_input)
            spike1, mem1 = self.lif1(current1, mem1)

            current2 = self.fc2(spike1)
            spike2, mem2 = self.lif2(current2, mem2)

            current3 = self.fc3(spike2)
            spike3, mem3 = self.lif3(current3, mem3)

            spike3_rec.append(spike3)
            mem3_rec.append(mem3)

        return torch.stack(spike3_rec, dim=0), torch.stack(mem3_rec, dim=0)

# Dummy data
input_size = 10
hidden_size = 20
output_size = 2
batch_size = 5
time_steps = 50

inputs = torch.randn(time_steps, batch_size, input_size)
targets = torch.randint(0, output_size, (batch_size,))

model = SNN(input_size, hidden_size, output_size)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(1):
    optimizer.zero_grad()
    spikes, memories = model(inputs)
    # Compute loss on the last time step's output spikes for simplicity
    loss = criterion(spikes[-1], targets)
    loss.backward()
    optimizer.step()

print(f"Training loss: {loss.item()}")

# Save the model
torch.save(model.state_dict(), "snn_model.pth")

# Load the model
loaded_model = SNN(input_size, hidden_size, output_size)

# here is where the error should be produced
loaded_model.load_state_dict(torch.load("snn_model.pth"))

print("Model loaded successfully.")

The Error for v0.8.*

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 77
     74 loaded_model = SNN(input_size, hidden_size, output_size)
     76 # here is where the error should be produced
---> 77 loaded_model.load_state_dict(torch.load("snn_model.pth"))
     79 print("Model loaded successfully.")

File ~/miniconda3/envs/neuromorphic/lib/python3.12/site-packages/torch/nn/modules/module.py:2153, in Module.load_state_dict(self, state_dict, strict, assign)
   2148         error_msgs.insert(
   2149             0, 'Missing key(s) in state_dict: {}. '.format(
   2150                 ', '.join(f'"{k}"' for k in missing_keys)))
   2152 if len(error_msgs) > 0:
-> 2153     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2154                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2155 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for SNN:
    size mismatch for lif1.mem: copying a param with shape torch.Size([5, 20]) from checkpoint, the shape in current model is torch.Size([1]).
    size mismatch for lif2.mem: copying a param with shape torch.Size([5, 20]) from checkpoint, the shape in current model is torch.Size([1]).
    size mismatch for lif3.mem: copying a param with shape torch.Size([5, 2]) from checkpoint, the shape in current model is torch.Size([1]).
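
The messages above suggest that the saved state_dict contains a batch-shaped `mem` entry for each neuron, while a freshly constructed model holds a placeholder of shape [1]. One possible workaround, sketched here in plain PyTorch, is to drop checkpoint entries whose shape does not match the new model before loading. `StatefulLayer` and `drop_mismatched` are hypothetical names made up for this illustration; the layer is only a stand-in that reproduces the shape mismatch and is not snnTorch's implementation. The sketch also assumes the hidden state is re-initialized before inference, so discarding the saved `mem` loses nothing that is needed.

```python
import torch
import torch.nn as nn


class StatefulLayer(nn.Module):
    """Hypothetical stand-in for a neuron that keeps its state in a buffer."""

    def __init__(self):
        super().__init__()
        # Placeholder state of shape [1]; replaced on the first forward pass.
        self.register_buffer("mem", torch.zeros(1))

    def forward(self, x):
        # The buffer takes on the batch shape of the input.
        self.mem = x.detach().clone()
        return self.mem


def drop_mismatched(model, state_dict):
    """Keep only checkpoint entries whose shape matches the given model."""
    current = model.state_dict()
    return {
        key: value
        for key, value in state_dict.items()
        if key in current and current[key].shape == value.shape
    }


# Build a model, run one batch so that `mem` becomes [5, 20], and snapshot it.
net = nn.Sequential(nn.Linear(10, 20), StatefulLayer())
net(torch.randn(5, 10))
checkpoint = net.state_dict()

# A fresh model still has `mem` of shape [1], so a plain load fails.
fresh = nn.Sequential(nn.Linear(10, 20), StatefulLayer())
try:
    fresh.load_state_dict(checkpoint)
    raised = False
except RuntimeError:
    raised = True

# Filtering first and loading with strict=False restores the weights only.
result = fresh.load_state_dict(drop_mismatched(fresh, checkpoint), strict=False)
print(raised, result.missing_keys)
```

Note that `strict=False` on its own would not be enough in this sketch: it tolerates missing or unexpected keys, but a size mismatch is still reported as an error, which is why the mismatched entries are filtered out first.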

stevenabreu7 commented 2 months ago

Running into the same issue! Have you made any progress on this?