fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

Online Training Through Time (OTTT) Training Function Error #593

Open tsumme1 opened 3 days ago


Issue type

SpikingJelly version

0.0.0.0.15

Description

When training a network with OTTT, functional.ottt_online_training raises a RuntimeError, seemingly because the neuron states are not detached from the computational graph after each time-step. Adding functional.detach_net(model) inside functional.ottt_online_training eliminated the error for me.

Code to reproduce the error

import torch
from torch import nn
from torch.nn import functional as F

from spikingjelly.activation_based import neuron, layer, functional

net = layer.OTTTSequential(
    nn.Linear(8, 4),
    neuron.OTTTLIFNode(),
    nn.Linear(4, 2),
    neuron.LIFNode()  # note: a plain LIFNode, not an OTTTLIFNode
)

optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

T = 4
N = 2
online = True
for epoch in range(2):

    x_seq = torch.rand([N, T, 8])
    target_seq = torch.rand([N, T, 2])

    functional.ottt_online_training(model=net, optimizer=optimizer, x_seq=x_seq, target_seq=target_seq, f_loss_t=F.mse_loss, online=online)
    functional.reset_net(net)

Error

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.