asudeeaydin opened this issue 2 years ago
Yes, we save the states to reduce memory usage during training. At inference, you can pass in empty tensors instead.
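As a minimal sketch of what this could look like at inference time: instead of reusing the states saved during training, you start each forward pass from freshly initialized (zero) states. The `lif_step` function and its parameters below are illustrative stand-ins, not the repo's actual API.

```python
import torch

def lif_step(x, v_prev, beta=0.9, threshold=1.0):
    """One leaky integrate-and-fire step from an explicit previous state.
    A toy stand-in for a stateful SNN layer; not the repo's real code."""
    v = beta * v_prev + x                 # leaky integration
    spikes = (v >= threshold).float()     # fire where threshold is crossed
    v = v - spikes * threshold            # soft reset after a spike
    return spikes, v

x = torch.full((4,), 0.6)     # constant input current
v = torch.zeros_like(x)       # "empty" state: start from rest at inference
with torch.no_grad():         # no autograd graph needed at inference
    for _ in range(3):
        spikes, v = lif_step(x, v)
```

Starting from zeros (rather than the training-time states) is safe here because the state only encodes history, which a fresh inference sequence does not have yet.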
How exactly does this save memory? When I look at the MemoryModule in base, I see that the states are being detached, but only for the last layer.
My thought was that you might be training with truncated BPTT, but for that you would need to detach the entire network at a certain time step, not only the last layer?
All the states are detached during training. Since we train the network on continuous voxel input, the state of the last layer is additionally used to compute the loss.
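To illustrate why detaching the carried states bounds memory, here is a hedged sketch of detach-based truncated BPTT for a stateful layer. The `LeakyState` class and `detach_state` method are hypothetical names for illustration, not the repo's MemoryModule API.

```python
import torch

class LeakyState(torch.nn.Module):
    """Toy stateful layer that carries a membrane potential across chunks."""
    def __init__(self, beta=0.9):
        super().__init__()
        self.beta = beta
        self.v = None  # membrane potential, carried between forward calls

    def forward(self, x):
        if self.v is None:
            self.v = torch.zeros_like(x)
        self.v = self.beta * self.v + x
        return self.v

    def detach_state(self):
        # Cut the autograd graph at the chunk boundary: later chunks reuse
        # the *value* of v, but gradients stop flowing through it, so the
        # graph (and its memory) does not grow across the whole sequence.
        if self.v is not None:
            self.v = self.v.detach()

layer = LeakyState()
w = torch.ones(3, requires_grad=True)

out1 = layer(w * 2.0)   # chunk 1 builds a graph through w
layer.detach_state()    # truncate BPTT between chunks
out2 = layer(w * 1.0)   # chunk 2 reuses the detached state value
loss = out2.sum()
loss.backward()
# w.grad reflects only chunk 2 (d(w * 1.0)/dw = 1 per element), because the
# carried state was detached; without detach it would also include chunk 1.
```

Without the `detach_state()` call, `out2` would depend on chunk 1 through the state, and the gradient would accumulate contributions from the entire history, which is exactly the memory cost the detach avoids.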
Quick question that I couldn't be sure about:
On line 407 of model.snn_network, why is the previous membrane potential passed in? Is this just to save the states (i.e. membrane potentials) of the neurons in the last layer (the output layer) for further analysis? If so, how do you access these states later on?