I wanted to load a model that I had trained with snnTorch. Loading had never caused problems before, so I investigated further and found that `loaded_model.load_state_dict(torch.load("snn_model.pth"))` throws an error with snnTorch 0.8.x but not with 0.7.x.
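According to Prof. Eshraghian, "There have been some changes to refactor/simplify neuron models, and also to make it compatible with torch.compile()." These changes could be the source of the problem. The error below shows the mismatch:
- The 0.8.x checkpoint contains each neuron's membrane potential (`lif1.mem`, `lif2.mem`, `lif3.mem`), saved with the shape of the last batch.
- A freshly constructed model holds these entries with shape `[1]`.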
What I Did
In case the attached file doesn't load, here is a self-contained script you can run to reproduce the error:
import torch
import torch.nn as nn
import snntorch as snn
class SNN(nn.Module):
    """Three-layer fully connected spiking network built from snnTorch ``Leaky`` neurons.

    Each layer is an ``nn.Linear`` followed by an ``snn.Leaky`` neuron with
    ``beta=0.7``. ``forward`` unrolls the network over dim 0 (time) of its input.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Registration order (flatten, fc1, lif1, fc2, lif2, fc3, lif3) determines
        # the order of entries in state_dict().
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.lif1 = snn.Leaky(beta=0.7)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.lif2 = snn.Leaky(beta=0.7)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.lif3 = snn.Leaky(beta=0.7)

    def forward(self, inpt):
        """Run the network once per time step of ``inpt``.

        Args:
            inpt: tensor whose dim 0 indexes time steps; each slice is passed
                through ``nn.Flatten`` before ``fc1``. Presumably shaped
                (time, batch, input_size) — TODO confirm for other callers.

        Returns:
            Tuple ``(spikes, membranes)`` for the output layer, each stacked
            along a new dim 0 with one entry per time step.
        """
        # Starting membrane potentials for this call come from snnTorch's init_leaky().
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        out_spikes = []
        out_mems = []
        # Iterating a tensor yields its slices along dim 0, i.e. one per time step.
        for x_t in inpt:
            cur1 = self.fc1(self.flatten(x_t))
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            cur3 = self.fc3(spk2)
            spk3, mem3 = self.lif3(cur3, mem3)

            out_spikes.append(spk3)
            out_mems.append(mem3)

        return torch.stack(out_spikes, dim=0), torch.stack(out_mems, dim=0)
# Dummy data: inputs are (time_steps, batch_size, input_size); one class label per sample.
input_size = 10
hidden_size = 20
output_size = 2
batch_size = 5
time_steps = 50
inputs = torch.randn(time_steps, batch_size, input_size)
targets = torch.randint(0, output_size, (batch_size,))

model = SNN(input_size, hidden_size, output_size)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(1):
    optimizer.zero_grad()
    spikes, memories = model(inputs)
    # Compute loss on the last time step's output spikes for simplicity
    loss = criterion(spikes[-1], targets)
    loss.backward()
    optimizer.step()
    print(f"Training loss: {loss.item()}")

# Save the model
torch.save(model.state_dict(), "snn_model.pth")

# Load the model
loaded_model = SNN(input_size, hidden_size, output_size)
# The checkpoint is a plain dict of tensors, so there is no need to unpickle
# arbitrary objects.
checkpoint = torch.load("snn_model.pth", weights_only=True)
try:
    # Strict load of the raw checkpoint: this is where the reported error is
    # produced on snntorch 0.8.x (it succeeds on 0.7.x).
    loaded_model.load_state_dict(checkpoint)
except RuntimeError as err:
    print(f"Strict load failed:\n{err}")
    # On 0.8.x the checkpoint also carries each Leaky neuron's membrane potential
    # (lif1.mem, lif2.mem, lif3.mem). These were saved with the shape of the last
    # batch ([5, 20] / [5, 2]), while a fresh model holds shape [1]. This is
    # transient hidden state, not learned weights, so keep the fresh model's
    # values and load everything else.
    learned = {k: v for k, v in checkpoint.items() if not k.endswith(".mem")}
    result = loaded_model.load_state_dict(learned, strict=False)
    # Only the dropped hidden-state entries may be missing; anything else is a
    # genuine checkpoint/model mismatch and must not be silently ignored.
    problems = [k for k in result.missing_keys if not k.endswith(".mem")]
    problems += result.unexpected_keys
    if problems:
        raise RuntimeError(f"Checkpoint does not match model: {problems}") from err
print("Model loaded successfully.")
The Error for v0.8.*
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[2], [line 77](vscode-notebook-cell:?execution_count=2&line=77)
[74](vscode-notebook-cell:?execution_count=2&line=74) loaded_model = SNN(input_size, hidden_size, output_size)
[76](vscode-notebook-cell:?execution_count=2&line=76) # here is where the error should be produced
---> [77](vscode-notebook-cell:?execution_count=2&line=77) loaded_model.load_state_dict(torch.load("snn_model.pth"))
[79](vscode-notebook-cell:?execution_count=2&line=79) print("Model loaded successfully.")
File [~/miniconda3/envs/neuromorphic/lib/python3.12/site-packages/torch/nn/modules/module.py:2153](https://vscode-remote+wsl-002bubuntu-002d20-002e04.vscode-resource.vscode-cdn.net/home/copparihollmann/neuroTUM/SpikingC/SpikingCpp/notebooks/~/miniconda3/envs/neuromorphic/lib/python3.12/site-packages/torch/nn/modules/module.py:2153), in Module.load_state_dict(self, state_dict, strict, assign)
[2148](https://vscode-remote+wsl-002bubuntu-002d20-002e04.vscode-resource.vscode-cdn.net/home/copparihollmann/neuroTUM/SpikingC/SpikingCpp/notebooks/~/miniconda3/envs/neuromorphic/lib/python3.12/site-packages/torch/nn/modules/module.py:2148) error_msgs.insert(
[2149](https://vscode-remote+wsl-002bubuntu-002d20-002e04.vscode-resource.vscode-cdn.net/home/copparihollmann/neuroTUM/SpikingC/SpikingCpp/notebooks/~/miniconda3/envs/neuromorphic/lib/python3.12/site-packages/torch/nn/modules/module.py:2149) 0, 'Missing key(s) in state_dict: {}. '.format(
[2150](https://vscode-remote+wsl-002bubuntu-002d20-002e04.vscode-resource.vscode-cdn.net/home/copparihollmann/neuroTUM/SpikingC/SpikingCpp/notebooks/~/miniconda3/envs/neuromorphic/lib/python3.12/site-packages/torch/nn/modules/module.py:2150) ', '.join(f'"{k}"' for k in missing_keys)))
[2152](https://vscode-remote+wsl-002bubuntu-002d20-002e04.vscode-resource.vscode-cdn.net/home/copparihollmann/neuroTUM/SpikingC/SpikingCpp/notebooks/~/miniconda3/envs/neuromorphic/lib/python3.12/site-packages/torch/nn/modules/module.py:2152) if len(error_msgs) > 0:
-> [2153](https://vscode-remote+wsl-002bubuntu-002d20-002e04.vscode-resource.vscode-cdn.net/home/copparihollmann/neuroTUM/SpikingC/SpikingCpp/notebooks/~/miniconda3/envs/neuromorphic/lib/python3.12/site-packages/torch/nn/modules/module.py:2153) raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
[2154](https://vscode-remote+wsl-002bubuntu-002d20-002e04.vscode-resource.vscode-cdn.net/home/copparihollmann/neuroTUM/SpikingC/SpikingCpp/notebooks/~/miniconda3/envs/neuromorphic/lib/python3.12/site-packages/torch/nn/modules/module.py:2154) self.__class__.__name__, "\n\t".join(error_msgs)))
[2155](https://vscode-remote+wsl-002bubuntu-002d20-002e04.vscode-resource.vscode-cdn.net/home/copparihollmann/neuroTUM/SpikingC/SpikingCpp/notebooks/~/miniconda3/envs/neuromorphic/lib/python3.12/site-packages/torch/nn/modules/module.py:2155) return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for SNN:
size mismatch for lif1.mem: copying a param with shape torch.Size([5, 20]) from checkpoint, the shape in current model is torch.Size([1]).
size mismatch for lif2.mem: copying a param with shape torch.Size([5, 20]) from checkpoint, the shape in current model is torch.Size([1]).
size mismatch for lif3.mem: copying a param with shape torch.Size([5, 2]) from checkpoint, the shape in current model is torch.Size([1]).
reproduce_error_load_snntorch.txt
Description
I wanted to load a model that I had trained with snnTorch. Loading had never caused problems before, so I investigated further and found that
loaded_model.load_state_dict(torch.load("snn_model.pth"))
throws an error with snnTorch 0.8.x but not with 0.7.x. According to Prof. Eshraghian, "There have been some changes to refactor/simplify neuron models, and also to make it compatible with torch.compile()." These changes could be the source of the problem.
What I Did
In case the attached file doesn't load, here is a self-contained script you can run to reproduce the error:
The Error for v0.8.*