jeshraghian / snntorch

Deep and online learning with spiking neural networks in Python
https://snntorch.readthedocs.io/en/latest/
MIT License
1.28k stars, 217 forks

Add support for `torch.compile(fullgraph=True)` WIP #271

Closed: gekkom closed this 8 months ago

gekkom commented 9 months ago

These are experimental changes that add support for `torch.compile(fullgraph=True)`. Please note that only the default mode is supported right now, so cudagraphs probably won't work.

Currently only `Leaky` is supported, and the changes can be tested in `examples/tutorial_5_fullgraph.ipynb`.
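As a rough standalone illustration of what `fullgraph=True` demands, here is a hypothetical `TinyLeaky` module (not snntorch's `Leaky`) written in plain PyTorch, assuming PyTorch 2.x. With `fullgraph=True`, Dynamo raises an error on any graph break instead of silently splitting the forward pass. The sketch passes `backend="eager"` so it only checks tracing and runs without a C++ toolchain.

```python
import torch
import torch.nn as nn


class TinyLeaky(nn.Module):
    """Hypothetical leaky integrate-and-fire neuron (not snntorch's Leaky)."""

    def __init__(self, beta: float = 0.9, threshold: float = 1.0):
        super().__init__()
        self.beta = beta
        self.threshold = threshold

    def forward(self, cur: torch.Tensor, mem: torch.Tensor):
        # Leaky integration: decay the old potential and add the input current.
        mem = self.beta * mem + cur
        # Emit a spike wherever the potential exceeds the threshold.
        spk = (mem > self.threshold).to(cur.dtype)
        # Reset by subtraction so the potential drops after each spike.
        mem = mem - spk * self.threshold
        return spk, mem


def run(model: nn.Module, inputs: torch.Tensor):
    # Unroll over the leading (time) dimension, carrying the membrane state.
    mem = torch.zeros_like(inputs[0])
    spikes = []
    for cur in inputs:
        spk, mem = model(cur, mem)
        spikes.append(spk)
    return torch.stack(spikes), mem


neuron = TinyLeaky()
# fullgraph=True: fail loudly on any graph break.
# backend="eager": trace with Dynamo but skip code generation.
compiled = torch.compile(neuron, fullgraph=True, backend="eager")

inputs = torch.rand(25, 4, 10)  # (time steps, batch, features)
spk_eager, mem_eager = run(neuron, inputs)
spk_comp, mem_comp = run(compiled, inputs)
```

Dropping `backend="eager"` selects the default inductor backend, which is the mode the PR targets.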

jeshraghian commented 9 months ago

At a quick glance, this is in nice shape. I'll think further on how to handle the specification of the hidden layer size in arguments, as this will break backwards compatibility. At the moment, the size of the layer is inferred in a similar way to how this is done with batch sizes. It might need to be an optional argument.
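One way to infer the hidden state's size the way batch size is inferred is to allocate the state lazily from the first input's shape, so no `layer_size` argument is needed. A minimal sketch, assuming a hypothetical `LazyStateLeaky` module; snntorch's actual mechanism may differ.

```python
import torch
import torch.nn as nn


class LazyStateLeaky(nn.Module):
    """Hypothetical neuron whose hidden-state shape is inferred from its input."""

    def __init__(self, beta: float = 0.9, threshold: float = 1.0):
        super().__init__()
        self.beta = beta
        self.threshold = threshold
        # No layer_size argument: start with an empty placeholder state.
        self.mem = torch.zeros(0)

    def reset(self):
        # Return to the empty placeholder; the next forward re-infers the shape.
        self.mem = torch.zeros(0)

    def forward(self, cur: torch.Tensor) -> torch.Tensor:
        # Re-create the state whenever the input shape (batch or layer size) changes.
        if self.mem.shape != cur.shape:
            self.mem = torch.zeros_like(cur)
        self.mem = self.beta * self.mem + cur
        spk = (self.mem > self.threshold).to(cur.dtype)
        self.mem = self.mem - spk * self.threshold
        return spk
```

An optional explicit size could be layered on top by preallocating when it is given and falling back to this inference otherwise.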

gekkom commented 9 months ago

I removed the layer_size requirement.

The only breaking change so far is that `spike_grad` needs to dynamically store a class, rather than a method, if the surrogate is defined outside of `nn.Module`. This gets around this issue with PyTorch: https://github.com/pytorch/pytorch/issues/112670. It shouldn't be a big issue, but it needs to be mentioned in the docs somewhere.

To be clear: if a user wants a custom `spike_grad` method, it should be created within the `nn.Module`. The only case where a custom class needs to (and should) be used is for custom autograd functions.
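A minimal sketch of the class-based option, assuming a fast-sigmoid surrogate gradient with `slope=25`. `FastSigmoidSpike` and `SpikingLayer` are hypothetical names, and snntorch's own surrogate classes may be structured differently. The module stores the `torch.autograd.Function` subclass itself in `spike_grad` and calls its `apply` method in `forward`.

```python
import torch
import torch.nn as nn


class FastSigmoidSpike(torch.autograd.Function):
    """Heaviside spike in the forward pass, fast-sigmoid surrogate in the backward pass."""

    @staticmethod
    def forward(ctx, mem, slope):
        ctx.save_for_backward(mem)
        ctx.slope = slope
        return (mem > 0).to(mem.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (mem,) = ctx.saved_tensors
        # Derivative of the fast sigmoid: 1 / (slope * |mem| + 1)^2
        grad = grad_output / (ctx.slope * mem.abs() + 1.0) ** 2
        return grad, None  # no gradient with respect to slope


class SpikingLayer(nn.Module):
    def __init__(self, spike_grad=FastSigmoidSpike, slope: float = 25.0):
        super().__init__()
        # Store the Function subclass itself, not a bound method or closure.
        self.spike_grad = spike_grad
        self.slope = slope

    def forward(self, mem: torch.Tensor) -> torch.Tensor:
        return self.spike_grad.apply(mem, self.slope)
```

A surrogate that needs no custom backward can instead be an ordinary method inside the module, which is the route suggested above for everything except custom autograd functions.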

jeshraghian commented 8 months ago

Sorry for how long it's taken me to get around to this.

I fixed a conflict so the tests can run, though the absence of the `init_leaky` functions is causing tests to fail, so the tests might need a bit of refactoring.

I'll test your PR on a few corner-case models. If all of those work, I'll go ahead and adapt the tests to your approach.

jeshraghian commented 8 months ago

The example in `examples/tutorial_5_fullgraph.ipynb` hit me with some `torch.compile()`-related errors when running on CUDA, but it works just fine on CPU.

Otherwise, all my training-loop tests across CPU/GPU/different architectures are working.
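A device-parametrized smoke test along these lines can be sketched as follows, assuming PyTorch 2.x. `available_devices` and `smoke_train` are hypothetical helpers, and a small feed-forward model stands in for an SNN. With `backend="eager"` only Dynamo tracing is checked; passing `backend="inductor"` also exercises code generation on each device.

```python
import torch
import torch.nn as nn


def available_devices() -> list:
    # Always test CPU; add CUDA only when a GPU is present.
    return ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


def smoke_train(device: str, steps: int = 3, backend: str = "eager") -> list:
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2)).to(device)
    # The compiled wrapper shares parameters with `model`, so the optimizer sees them.
    compiled = torch.compile(model, fullgraph=True, backend=backend)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(32, 8, device=device)
    y = torch.randint(0, 2, (32,), device=device)
    losses = []
    for _ in range(steps):
        opt.zero_grad()
        loss = nn.functional.cross_entropy(compiled(x), y)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses


for device in available_devices():
    print(device, smoke_train(device))
```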

I just have to see what's happening in `test_lapicque.py` and `test_leaky.py`, and that should be it.

jeshraghian commented 8 months ago

`lapicque.py` had a dependency on the `init_leaky` function, so I've fixed it for the Lapicque neuron only. All but one test now pass.

The only other error I see is in `test_leaky_cases` in `test_leaky.py`, where a `Leaky` neuron is instantiated with `init_hidden=True`.

My intention with this test was to cover the case where two inputs are fed to the neuron (i.e., the input current and the membrane potential) even though `init_hidden=True`, in which case the membrane potential shouldn't be passed explicitly as an argument to the forward pass.

At the moment, the external membrane potential overrides the instance variable. Testing a fix now.
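The thread doesn't say which fix was adopted. One option, sketched here with a hypothetical `HiddenStateLeaky` module (not snntorch's implementation), is to let the instance variable own the state when `init_hidden=True` and raise a `TypeError` if a membrane potential is also passed, so a test can assert on the error instead of the external value silently winning.

```python
import torch
import torch.nn as nn


class HiddenStateLeaky(nn.Module):
    """Hypothetical neuron showing one way to handle init_hidden=True."""

    def __init__(self, beta: float = 0.9, threshold: float = 1.0, init_hidden: bool = False):
        super().__init__()
        self.beta = beta
        self.threshold = threshold
        self.init_hidden = init_hidden
        self.mem = torch.zeros(0)

    def forward(self, cur: torch.Tensor, mem: torch.Tensor = None):
        if self.init_hidden:
            # The instance variable owns the state; refuse an external one
            # rather than letting it override self.mem.
            if mem is not None:
                raise TypeError("mem must not be passed when init_hidden=True")
            if self.mem.shape != cur.shape:
                self.mem = torch.zeros_like(cur)
            mem = self.mem
        elif mem is None:
            mem = torch.zeros_like(cur)
        mem = self.beta * mem + cur
        spk = (mem > self.threshold).to(cur.dtype)
        mem = mem - spk * self.threshold
        if self.init_hidden:
            self.mem = mem
            return spk
        return spk, mem
```

Silently ignoring the extra argument is the other obvious choice; raising makes the misuse visible.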

jeshraghian commented 8 months ago

All tests are passing! I think this is good for merging. It might be good to keep this open here so other neurons can be refactored into the same style.