fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

Calling spikingjelly functional.reset_net() in a PyTorch Lightning module #288

Closed gwgknudayanga closed 1 year ago

gwgknudayanga commented 1 year ago

Hi, within a PyTorch Lightning module I am calling functional.reset_net(model) inside training_step() as follows. Is that fine? Because, as I understand it, at the moment reset_net() is called the backward pass for this step has not happened yet.

```python
import torch.nn.functional as F
import pytorch_lightning as pl
from spikingjelly.clock_driven import functional


class myTrainModule(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model  # model is an nn.Sequential with conv layers and spiking neurons

    def training_step(self, batch, batch_idx):
        x, target = batch
        output = self.model(x)
        loss = F.cross_entropy(output, target)
        functional.reset_net(self.model)  # reset neuron states before returning the loss
        return loss
```

fangwei123456 commented 1 year ago

Hi, I do not know much about PyTorch Lightning, but calling reset before backward will not cause an error. For example, after reset, `v` in the neuron will be set to 0, but the tensors that backward needs have already been stored by the autograd functions. (If they were not, i.e. if they had been modified in place, the autograd engine would raise an in-place error.)
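To illustrate the point above, here is a minimal sketch in plain PyTorch (not SpikingJelly; the variable names are only illustrative). Rebinding a state tensor to zeros after the loss has been computed, which is analogous to what the reset described above does, leaves backward untouched, because autograd has already saved the tensors it needs; only an in-place write to a saved tensor would raise an error.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
v = torch.zeros(())        # stand-in for a neuron's membrane potential

v = v + 3.0 * w            # state update that depends on w
loss = (v - 1.0) ** 2      # loss computed from the state

v = torch.zeros_like(v)    # "reset": rebind v to a fresh tensor (not an in-place write)
loss.backward()            # still works; the graph saved the old v before the rebind
print(w.grad)              # tensor(30.) = 2 * (3*2 - 1) * 3
```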

gwgknudayanga commented 1 year ago

Thank you for the explanation!