Closed: gekkom closed this PR 8 months ago.
At a quick glance, this is in nice shape. I'll think further on how to handle the specification of the hidden layer size in arguments, as this will break backwards compatibility. At the moment, the size of the layer is inferred in a similar way to how this is done with batch sizes. It might need to be an optional argument.
I removed the `layer_size` requirement.
The only breaking change so far is that `spike_grad` needs to store a class dynamically, rather than a method, when it is defined outside the `nn.Module`, to work around this PyTorch issue: https://github.com/pytorch/pytorch/issues/112670. This shouldn't be a big issue, but it needs to be mentioned somewhere in the docs. To be clear: if a user wants a custom `spike_grad` method, it should be created within the `nn.Module`; the only case where a custom class needs to be (and should be) used is for custom autograd methods.
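To illustrate the class-based pattern described above, here is a hedged sketch of a custom surrogate spike gradient written as a `torch.autograd.Function` subclass. The name `FastSigmoidSurrogate` and the slope value are hypothetical, not taken from the PR; the point is only that the class itself (not a bound method) is what gets stored and traced:

```python
import torch

class FastSigmoidSurrogate(torch.autograd.Function):
    """Hypothetical custom autograd surrogate gradient, stored as a
    class rather than a bound method so it can be referenced safely
    outside an nn.Module (see pytorch/pytorch#112670)."""

    @staticmethod
    def forward(ctx, mem, slope=25.0):
        ctx.save_for_backward(mem)
        ctx.slope = slope
        # Heaviside step: spike when membrane potential crosses zero
        return (mem > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (mem,) = ctx.saved_tensors
        # Fast-sigmoid surrogate: d(spike)/d(mem) ~ 1 / (slope*|mem| + 1)^2
        grad = grad_output / (ctx.slope * mem.abs() + 1.0) ** 2
        return grad, None  # no gradient w.r.t. slope

mem = torch.tensor([0.5, -0.2], requires_grad=True)
spk = FastSigmoidSurrogate.apply(mem)
spk.sum().backward()
```

A user would then pass `FastSigmoidSurrogate.apply` (or the class, depending on the neuron's API) as `spike_grad`, rather than a method defined on the fly.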
Sorry for how long it's taken me to get around to this.
I fixed a conflict to enable the tests to run, though the absence of the `init_leaky` functions is causing tests to fail, so some refactoring of the tests may be needed.
I'll test your PR on a few different corner cases of models; if all of those work, I'll go ahead and modify the tests to suit your approach.
The example in `examples/tutorial_5_fullgraph.ipynb` hit me with some `torch.compile()`-related errors when running on CUDA, but works just fine on CPU. Otherwise, all my training-loop tests across CPU/GPU/different architectures are working. I just have to see what's happening in `test_lapicque.py` and `test_leaky.py`, and that should be it.
`lapicque.py` had a dependency on the `init_leaky` function, so I've fixed it strictly for the Lapicque neuron. All but one test is now passing. The only other error I see is in `test_leaky_cases` in `test_leaky.py`, where a leaky neuron is instantiated with `init_hidden=True`.
My intention with this test was to cover the case where two inputs are fed to the neuron (i.e., input current and membrane potential) but `init_hidden=True` is set, in which case the membrane potential shouldn't be explicitly fed as an argument to the forward pass. At the moment, the externally passed membrane potential overrides the instance variable. Testing a fix now.
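As a minimal sketch of the intended `init_hidden` semantics, here is a hypothetical stand-in neuron (this `TinyLeaky` is an illustration, not snnTorch's actual `Leaky` implementation): when `init_hidden=True`, the membrane state lives on the instance, and an externally passed membrane potential should be ignored rather than override it.

```python
import torch
import torch.nn as nn

class TinyLeaky(nn.Module):
    """Hypothetical minimal leaky integrator illustrating init_hidden
    semantics: with init_hidden=True the membrane state is kept on the
    instance, and a positional mem argument must not override it."""

    def __init__(self, beta=0.9, init_hidden=False, threshold=1.0):
        super().__init__()
        self.beta = beta
        self.init_hidden = init_hidden
        self.threshold = threshold
        self.mem = None  # instance-held state when init_hidden=True

    def forward(self, x, mem=None):
        if self.init_hidden:
            # Internal state takes precedence; an externally passed
            # mem is ignored rather than overriding the instance state.
            if self.mem is None:
                self.mem = torch.zeros_like(x)
            self.mem = self.beta * self.mem + x
            spk = (self.mem > self.threshold).float()
            self.mem = self.mem - spk * self.threshold  # soft reset
            return spk
        # Stateless path: caller supplies and receives the membrane state.
        mem = torch.zeros_like(x) if mem is None else mem
        mem = self.beta * mem + x
        spk = (mem > self.threshold).float()
        return spk, mem - spk * self.threshold

neuron = TinyLeaky(beta=0.5, init_hidden=True)
x = torch.ones(2)
spk1 = neuron(x)                          # mem: 0 -> 1.0, no spike
spk2 = neuron(x, mem=torch.zeros(2))      # external mem ignored; mem: 1.0 -> 1.5, spikes
```

The second call passes an external `mem`, but the instance state still drives the update, which is the behavior the failing test is meant to pin down.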
All tests are passing! I think this is good to merge. It might be good to keep this open so other neurons can be refactored into the same style.
These are experimental changes that add support for `torch.compile(fullgraph=True)` (please note that only the default mode is supported right now, so cudagraphs probably won't work). Currently only Leaky is supported, and the changes can be tested in `examples/tutorial_5_fullgraph.ipynb`.
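For reference, a minimal usage sketch of the compile call the PR targets. The tiny `nn.Sequential` here is a hypothetical stand-in for a Leaky-based model, and `backend="eager"` is used only so the demo runs without a compiler toolchain; the PR itself targets the default (inductor) backend:

```python
import torch
import torch.nn as nn

# Hypothetical tiny network standing in for a Leaky-based snnTorch model.
net = nn.Sequential(nn.Linear(4, 2), nn.ReLU())

# fullgraph=True asks Dynamo to capture the whole forward pass as a
# single graph with no graph breaks; per the PR, only the default
# compile mode is expected to work (cudagraphs likely won't).
compiled = torch.compile(net, fullgraph=True, backend="eager")

out = compiled(torch.randn(3, 4))
```

If the model's forward pass causes a graph break (e.g., via untraceable Python state), `fullgraph=True` raises an error instead of silently falling back, which is what makes it a useful smoke test for this refactor.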