synsense / sinabs

A deep learning library for spiking neural networks which is based on PyTorch, focuses on fast training and supports inference on neuromorphic hardware.
https://sinabs.readthedocs.io
GNU Affero General Public License v3.0

Support tensors as input to hardware model #183

Open ssinhaleite opened 9 months ago

ssinhaleite commented 9 months ago

Just like most torch modules, DynapcnnNetwork takes a tensor as input to its forward method. That is, until you deploy it on hardware. From then on, the same model expects a list of events as input.

It would be convenient (and consistent) if the deployed model could still accept tensors as input. The conversion to events could then happen inside the forward function. Ideally, the model would accept either event lists or tensors and convert if necessary.

bauerfe commented 7 months ago

I've been thinking a bit more about this. There are basically three input formats that we typically encounter:

1. Torch Tensors with binned (rasterized) events

The standard format in Sinabs for software simulations. Converting it to events requires knowing the time bin size. If that is known, the ChipFactory can convert the tensor to a list of events.

2. Structured numpy arrays (xytp)

Very common in many datasets. The ChipFactory can convert this to a list of events.

3. List of device events

The format that needs to be passed to samna, and currently the only supported input format. I would argue that in most cases this type of data has been generated in sinabs by converting another format (either one of the two above, or maybe events that have previously been recorded from the chip, which are afaik not the same class).
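
To illustrate why format 1 cannot be converted without the time bin size, here is a minimal sketch in plain NumPy. It is not Sinabs or samna code: the function name `raster_to_xytp`, the parameter `dt_us`, the `(T, P, Y, X)` axis order and the field names and dtypes of the structured array are all assumptions made for the example.

```python
import numpy as np

# Hypothetical xytp layout; real datasets may use other field names or dtypes.
XYTP_DTYPE = np.dtype(
    [("x", np.uint16), ("y", np.uint16), ("t", np.int64), ("p", np.uint8)]
)


def raster_to_xytp(raster: np.ndarray, dt_us: int) -> np.ndarray:
    """Convert a binned event raster into a structured xytp array.

    `raster` is assumed to have shape (T, P, Y, X) and to hold non-negative
    event counts. Every event in time bin i gets the timestamp i * dt_us, so
    the bin size must be supplied; it cannot be recovered from the raster.
    """
    t_idx, p_idx, y_idx, x_idx = np.nonzero(raster)
    counts = raster[t_idx, p_idx, y_idx, x_idx].astype(np.int64)
    events = np.empty(int(counts.sum()), dtype=XYTP_DTYPE)
    # A bin that holds n events yields n identical entries.
    events["x"] = np.repeat(x_idx, counts)
    events["y"] = np.repeat(y_idx, counts)
    events["t"] = np.repeat(t_idx, counts) * dt_us
    events["p"] = np.repeat(p_idx, counts)
    return events


# Three time bins, two polarities, 4x4 pixels.
raster = np.zeros((3, 2, 4, 4), dtype=np.int64)
raster[0, 1, 2, 3] = 1  # one event in bin 0
raster[2, 0, 0, 1] = 2  # two events in bin 2
events = raster_to_xytp(raster, dt_us=1000)
```

The same raster with a different `dt_us` produces different timestamps, which is why the bin size has to travel with the data.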

I would suggest the following methods for a DynapcnnNetwork:

forward

Accepts all three formats and invokes the corresponding method for each. The formats can easily be distinguished by data type.

forward_from_events

What forward does as of now: it accepts a list of events and passes them to the chip.

forward_from_xytp

Accepts a structured array of events, uses a chip factory to convert them to a list of events and then passes that to forward_from_events.

forward_from_tensor

Accepts a tensor of rasterized events, uses a chip factory to convert them to a list of events and then passes that to forward_from_events. In this case the time step of the input needs to be passed along somewhere. It could either be a parameter of that function or an attribute of the DynapcnnNetwork. I would argue that it makes more sense to have it as a function parameter, because the time step is only meaningful for the given data. Also, we avoid giving the impression that the chip is time-discrete.
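
A rough sketch of how the four methods could fit together. This is only an illustration of the proposal, not the DynapcnnNetwork API: the class `HardwareModelSketch` is hypothetical, a plain NumPy array stands in for the torch tensor, events are plain `(x, y, t, p)` tuples instead of samna event objects, and the conversions are done inline where the real thing would use a chip factory.

```python
import numpy as np


class HardwareModelSketch:
    """Hypothetical stand-in for a deployed model; not the Sinabs API."""

    def forward(self, data, dt_us=None):
        # Dispatch on data type, as proposed for `forward`.
        if isinstance(data, list):
            return self.forward_from_events(data)
        if isinstance(data, np.ndarray) and data.dtype.names is not None:
            # Structured array, i.e. xytp events.
            return self.forward_from_xytp(data)
        if isinstance(data, np.ndarray):  # stand-in for torch.Tensor
            if dt_us is None:
                raise ValueError("dt_us is required for rasterized input")
            return self.forward_from_tensor(data, dt_us)
        raise TypeError(f"Unsupported input type: {type(data).__name__}")

    def forward_from_events(self, events):
        # Placeholder for sending the events to the chip; just echoes them.
        return events

    def forward_from_xytp(self, xytp):
        events = [
            (int(e["x"]), int(e["y"]), int(e["t"]), int(e["p"])) for e in xytp
        ]
        return self.forward_from_events(events)

    def forward_from_tensor(self, raster, dt_us):
        # The time step is a call parameter, not a model attribute.
        events = []
        for t, p, y, x in zip(*np.nonzero(raster)):
            event = (int(x), int(y), int(t) * dt_us, int(p))
            events.extend([event] * int(raster[t, p, y, x]))
        return self.forward_from_events(events)
```

With this layout, `forward_from_events` stays the single place that talks to the hardware, and the other methods only differ in how they build the event list.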

Looking forward to any feedback on this.

sheiksadique commented 7 months ago

Having the helper functions listed above would definitely be very useful to the users.

Having said that, a forward function that directly accommodates any of the above input formats is something I find inconsistent with the PyTorch API. I would instead suggest that the forward method remain as it is now and that we provide the other explicit methods above for clarity and convenience.

Here is my thinking on this:

In PyTorch, the equivalent would be if the forward function took any tensor and converted/moved the input to the appropriate device that the module is on. This would have been an equivalent and relatively simple change, but they chose not to do this automatically and leave it to the user to handle it.

I believe this is partly because checking the data format and converting it in each forward call adds complexity to the forward function and would lead to slower processing. If the conversion is instead handled in parallel in the dataset/dataloader, it both reduces the complexity of the forward function and might also contribute to faster training and testing time.
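
A minimal sketch of what handling the conversion in the dataset could look like. Everything here is hypothetical and only meant to show the idea: `EventConvertingDataset` and `toy_to_events` are made-up names, the samples are nested lists indexed `[time][channel]`, and the toy conversion produces `(t, channel)` tuples rather than real device events.

```python
class EventConvertingDataset:
    """Hypothetical map-style dataset that converts samples when loaded."""

    def __init__(self, samples, to_events):
        self.samples = samples      # sequence of (data, label) pairs
        self.to_events = to_events  # conversion runs here, not in forward

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        data, label = self.samples[index]
        return self.to_events(data), label


def toy_to_events(raster, dt_us=1000):
    # Toy conversion: one (t, channel) tuple per spike count in the raster.
    return [
        (t * dt_us, channel)
        for t, row in enumerate(raster)
        for channel, count in enumerate(row)
        for _ in range(count)
    ]


dataset = EventConvertingDataset([([[0, 1], [2, 0]], 3)], toy_to_events)
events, label = dataset[0]
```

An object like this can be handed to `torch.utils.data.DataLoader`, where `num_workers > 0` would run the conversion in worker processes. Batching variable-length event lists would likely need a custom `collate_fn`.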

What do you guys think?

bauerfe commented 7 months ago

That's a good point. So far I have understood consistency as the forward method always taking the same (data) type as its argument. But it's true that being consistent (in particular, consistent with PyTorch) can also mean taking device-specific data as input, depending on where the model is deployed, and not doing any conversion under the hood.

For the sake of simplicity and clarity, let's do it as you suggest and add the explicit methods, while keeping forward as it is.

Something I hadn't considered so far: Should these new methods fail when the model is not yet deployed on event-based hw?