jeshraghian / snntorch

Deep and online learning with spiking neural networks in Python
https://snntorch.readthedocs.io/en/latest/
MIT License

Backprop issue with membrane potential reset in PPO #340

Open ADebor opened 3 months ago

ADebor commented 3 months ago

Description

Hi there,

I'm trying to implement a basic RL training loop for spiking nets using snntorch and torchrl. In this actor-critic setting, the actor is an SNN made of three parts: a population encoder, a spiking MLP, and a population decoder. The critic is a non-spiking ANN.

For context, the PPO algorithm comprises a rollout phase, during which data are gathered from the environment, followed by an update phase, during which the actor and the critic are updated.

In the encoder and the spiking MLP, I use Leaky neurons to generate and process spikes. I initialize these neurons with init_hidden set to True. The encoder and the MLP are two different nn.Modules, each defining its own forward method. In each of these methods, I initially called utils.reset(self.net) before any processing, where self.net is an nn.Sequential (containing a single Leaky neuron in the encoder's case, and several linear layers and Leaky neurons in the MLP's case).
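Schematically, the two modules look something like this (layer sizes, beta, and the single-step forward are placeholders rather than my actual code):

```python
import torch.nn as nn
import snntorch as snn
from snntorch import utils


class PopulationEncoder(nn.Module):
    """Projects observations onto a population of Leaky neurons that emit spikes."""

    def __init__(self, obs_dim=8, pop_dim=32, beta=0.9):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, pop_dim),
            snn.Leaky(beta=beta, init_hidden=True),  # hidden state kept inside the layer
        )

    def forward(self, x):
        utils.reset(self.net)  # reset hidden states before any processing
        return self.net(x)


class SpikingMLP(nn.Module):
    """Spiking MLP processing the encoder's spike trains."""

    def __init__(self, pop_dim=32, hidden_dim=64, out_dim=2, beta=0.9):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(pop_dim, hidden_dim),
            snn.Leaky(beta=beta, init_hidden=True),
            nn.Linear(hidden_dim, out_dim),
            snn.Leaky(beta=beta, init_hidden=True),
        )

    def forward(self, x):
        utils.reset(self.net)  # same pattern here
        return self.net(x)
```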

This goes fine during the rollout phase (at least, it runs without throwing any errors). However, problems arise in the update phase: the update loop crashes when calling backward() for the second time (i.e., on the second mini-batch of the data collected during the first rollout), with RuntimeError: Trying to backward through the graph a second time (...).
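For illustration, the update phase boils down to a loop like the one below (random tensors stand in for the data collected with torchrl, and the loss is a stand-in for the PPO objective); the crash happens at the marked line on the second mini-batch:

```python
import torch

actor = SpikingMLP()  # module sketched above
optimizer = torch.optim.Adam(actor.parameters(), lr=3e-4)

# Placeholder rollout: (num_minibatches, minibatch_size, pop_dim)
rollout = torch.randn(4, 64, 32)

for minibatch in rollout:
    out = actor(minibatch)
    loss = out.sum()        # stand-in for the PPO loss
    optimizer.zero_grad()
    loss.backward()         # <- RuntimeError: Trying to backward through the graph
                            #    a second time, raised on the second mini-batch
    optimizer.step()
```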

After digging in a bit, I noticed that this seems to come from the mem variable of the Leaky neuron(s). During the first loop iteration, the batch size changes (from "number of parallel environments" to "mini-batch size"), so self.mem is assigned a new tensor in Leaky's forward method. On the second iteration, however, the batch size no longer changes, so this reassignment does not happen. I thought that calling utils.reset(self.net) would have the same effect as assigning a new tensor, but that is not what I observe. In fact, the mem tensor manipulated in the reset_hidden class method does not seem to be the same as the one used in forward (its batch size always equals "number of parallel environments" in reset_hidden, whereas I'd expect it to change after the first training iteration). Reusing the same mem across two iterations seems to be what breaks backprop: my understanding is that the stale mem still references the graph built during the first mini-batch, which has already been freed by the first backward() call.
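For what it's worth, the hidden state stored in each Leaky layer can be inspected between update iterations with a snippet like this (not my exact code; a non-None grad_fn would mean the stored mem still carries a previous iteration's graph):

```python
# Inspect the membrane state stored in each Leaky layer of the actor
for name, module in actor.named_modules():
    if isinstance(module, snn.Leaky):
        print(name, module.mem.shape, module.mem.grad_fn)
```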

What I Did

I no longer get errors if I replace utils.reset(self.net) with net.a_leaky_neuron.reset_mem() in my modules' forward methods. I'm not sure the training is done properly, though; at the moment I'm only trying to get something running without errors.
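Concretely, the change looks something like this in the forward methods (sketched as a loop over the Sequential, which is equivalent to calling reset_mem() on each Leaky attribute directly):

```python
class SpikingMLP(nn.Module):
    def __init__(self, pop_dim=32, hidden_dim=64, out_dim=2, beta=0.9):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(pop_dim, hidden_dim),
            snn.Leaky(beta=beta, init_hidden=True),
            nn.Linear(hidden_dim, out_dim),
            snn.Leaky(beta=beta, init_hidden=True),
        )

    def forward(self, x):
        # Workaround: reset each Leaky layer directly rather than calling utils.reset(self.net)
        for layer in self.net:
            if isinstance(layer, snn.Leaky):
                layer.reset_mem()
        return self.net(x)
```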

I might be using snntorch's utils incorrectly, but does this problem ring a bell on your side? Do you see something wrong in the way I reset the neurons? Why does calling the utils method not work?

If needed, I can share some code, but it is quite bulky at the moment and it would take me a bit of work to produce a minimal example.

Thanks a lot!