When trying to train a network with OTTT, functional.ottt_online_training runs into a RuntimeError, seemingly because the neuron states are not detached from the computational graph after each time-step. Adding a functional.detach_net(model) call inside functional.ottt_online_training eliminated the error for me.
Code to reproduce the error
import torch
from torch import nn
from torch.nn import functional as F
from spikingjelly.activation_based import neuron, layer, functional
net = layer.OTTTSequential(
    nn.Linear(8, 4),
    neuron.OTTTLIFNode(),
    nn.Linear(4, 2),
    neuron.LIFNode()
)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
T = 4
N = 2
online = True
for epoch in range(2):
    x_seq = torch.rand([N, T, 8])
    target_seq = torch.rand([N, T, 2])
    functional.ottt_online_training(model=net, optimizer=optimizer, x_seq=x_seq,
                                    target_seq=target_seq, f_loss_t=F.mse_loss,
                                    online=online)
    functional.reset_net(net)
Error
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
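For context, the same failure mode can be reproduced and fixed outside of ottt_online_training. The snippet below is only a minimal illustration of the detach-based workaround, not the library's actual training code; it assumes a plain LIFNode in its default single-step mode, steps the network manually, and calls backward() at every time-step. Removing the functional.detach_net(net) line makes the second backward() fail with exactly the RuntimeError above, because the membrane potential from one time-step stays attached to the graph that the previous backward() has already freed.
import torch
from torch import nn
from torch.nn import functional as F
from spikingjelly.activation_based import neuron, functional

# Minimal single-step network; LIFNode keeps its membrane potential as state
# across calls, which is what ends up attached to the computational graph.
net = nn.Sequential(nn.Linear(8, 2), neuron.LIFNode())
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

T, N = 4, 2
x_seq = torch.rand([T, N, 8])
target_seq = torch.rand([T, N, 2])

for t in range(T):
    optimizer.zero_grad()
    loss = F.mse_loss(net(x_seq[t]), target_seq[t])
    loss.backward()   # frees the graph built during this time-step
    optimizer.step()
    # Cut the neuron states loose from the freed graph; without this line,
    # the next time-step's backward() raises the RuntimeError above.
    functional.detach_net(net)
functional.reset_net(net)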
SpikingJelly version
0.0.0.0.15