jeshraghian / snntorch

Deep and online learning with spiking neural networks in Python
https://snntorch.readthedocs.io/en/latest/
MIT License

Calling TBPTT crashes when time_var=True and K!=False #191

Closed: thebarnable closed this issue 1 year ago

thebarnable commented 1 year ago

Description

I tried training on N-MNIST (loaded with tonic) using RTRL. Calling RTRL to run an epoch immediately crashes because of an undefined variable (K_flag). On further inspection, RTRL calls TBPTT, and TBPTT always crashes when time_var=True and K!=False (see below).

What I Did

# minimal_example.py

import snntorch as snn
import snntorch.functional as SF
from snntorch import utils
from snntorch import backprop
from snntorch import surrogate
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import tonic
import tonic.transforms as transforms

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
spike_grad = surrogate.atan()  # arctan surrogate gradient for the spike non-linearity
net = nn.Sequential(nn.Conv2d(2, 12, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=0.9, spike_grad=spike_grad, init_hidden=True),
                    nn.Conv2d(12, 32, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=0.9, spike_grad=spike_grad, init_hidden=True),
                    nn.Flatten(),
                    nn.Linear(32*5*5, 10),  # 34x34 input -> 5x5 after two conv/pool stages
                    snn.Leaky(beta=0.9, spike_grad=spike_grad, init_hidden=True, output=True)
                    ).to(device)

optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss()

frame_transform = transforms.Compose([transforms.Denoise(filter_time=10000), 
                                      transforms.ToFrame(sensor_size=tonic.datasets.NMNIST.sensor_size, 
                                                         time_window=1000)
                                     ])
trainset = tonic.datasets.NMNIST(save_to='./data', transform=frame_transform, train=True)
trainloader = DataLoader(trainset, batch_size=128, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)

loss = backprop.RTRL(net, trainloader, optimizer=optimizer, criterion=loss_fn, time_var=True, device=device)

Running python minimal_example.py crashes with:

Traceback (most recent call last):
  File "[...]minimal_example.py", line 39, in <module>
    loss = backprop.RTRL(net, trainloader, optimizer=optimizer, criterion=loss_fn, time_var=True, device=device)
  File "[...]/miniconda3/envs/snn/lib/python3.10/site-packages/snntorch/backprop.py", line 524, in RTRL
    return TBPTT(
  File "[...]/miniconda3/envs/snn/lib/python3.10/site-packages/snntorch/backprop.py", line 199, in TBPTT
    if K_flag is False:
UnboundLocalError: local variable 'K_flag' referenced before assignment

Possible fix

This error only occurs for time-varying data (time_var=True). In backprop.py::TBPTT, the following code runs when time_var=True:

if K is False:
  K_flag = K
if K_flag is False:
  K = num_steps

When RTRL calls TBPTT (or whenever K is not False), the first condition is skipped. K_flag is not initialized anywhere else, so evaluating the second condition raises UnboundLocalError. Simple fix: initialize K_flag to True at the beginning of TBPTT.
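A minimal sketch of the failure and the proposed fix, with the K/K_flag logic pulled out of TBPTT into standalone functions so it can be run in isolation. The function names are hypothetical, num_steps is passed in as a plain argument (an assumption for the sketch), and K=1 is only an illustrative value for a call where K is not False.

```python
def resolve_k_buggy(K, num_steps):
    # Same logic as the snippet quoted from backprop.py::TBPTT (time_var=True)
    if K is False:
        K_flag = K
    if K_flag is False:  # raises UnboundLocalError when K is not False
        K = num_steps
    return K


def resolve_k_fixed(K, num_steps):
    K_flag = True  # proposed fix: give K_flag a value before either check
    if K is False:
        K_flag = K
    if K_flag is False:
        K = num_steps  # K is False: use the full sequence length
    return K


def resolve_k_simplified(K, num_steps):
    # Equivalent form: K_flag only ever records whether K is False
    return num_steps if K is False else K


if __name__ == "__main__":
    try:
        resolve_k_buggy(1, 300)  # illustrative K that is not False
    except UnboundLocalError as err:
        print("buggy:", type(err).__name__)
    print("fixed:", resolve_k_fixed(1, 300), resolve_k_fixed(False, 300))
```

Since K_flag carries no information beyond `K is False`, the simplified form removes the extra variable entirely and cannot hit the same unbound-variable path.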

jeshraghian commented 1 year ago

We've deprecated backprop.py because it was super inefficient and clunky to maintain. I'd recommend manually doing a weight update at each time step. You may run into errors at first, but if you detach the state variables consistently and re-initialize them using utils.reset(net), it should work fine.
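A rough sketch of the per-time-step approach described above, assuming PyTorch only. A hand-written leaky layer (LeakyLinear, hypothetical) stands in for nn.Linear followed by snn.Leaky so the example runs without snntorch, and its reset() method plays the role that utils.reset(net) would play for snntorch layers. The fast-sigmoid surrogate gradient and the subtractive reset are assumed choices, not necessarily snntorch's defaults. Each time step gets its own forward pass, loss, backward pass, and optimizer step; the membrane is then detached, since without that the next backward() would try to traverse the previous step's already-freed graph and would typically raise a RuntimeError.

```python
import torch
import torch.nn as nn


class SpikeFn(torch.autograd.Function):
    """Heaviside spike forward, fast-sigmoid surrogate backward (assumed choice)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return (x > 0).float()

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out / (1.0 + 10.0 * x.abs()) ** 2


class LeakyLinear(nn.Module):
    """Hypothetical stand-in for nn.Linear + snn.Leaky that stores its own membrane."""

    def __init__(self, n_in, n_out, beta=0.9, threshold=1.0):
        super().__init__()
        self.fc = nn.Linear(n_in, n_out)
        self.beta = beta
        self.threshold = threshold
        self.mem = None

    def reset(self):
        self.mem = None  # analogous to utils.reset(net) at the start of a batch

    def detach(self):
        if self.mem is not None:
            self.mem = self.mem.detach()  # cut the graph between time steps

    def forward(self, x):
        cur = self.fc(x)
        if self.mem is None:
            self.mem = torch.zeros_like(cur)
        self.mem = self.beta * self.mem + cur
        spk = SpikeFn.apply(self.mem - self.threshold)
        self.mem = self.mem - spk * self.threshold  # subtractive reset after a spike
        return spk, self.mem


def train_online(layer, optimizer, loss_fn, data, targets):
    """data: [num_steps, batch, n_in] (time-first); one weight update per time step."""
    layer.reset()
    losses = []
    for step in range(data.size(0)):
        spk, mem = layer(data[step])
        loss = loss_fn(mem, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        layer.detach()
        losses.append(loss.item())
    return losses
```

The time-first layout matches what a loader collated with batch_first=False would produce, but the toy shapes here are illustrative only. Resetting once per batch and detaching once per step confines each backward pass to a single time step, which is similar in spirit to running TBPTT with a truncation length of one.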

thebarnable commented 1 year ago

Fair enough, already went back to doing that. I'll close the issue then.