NeuroBench / neurobench

Benchmark harness and baseline results for the NeuroBench algorithm track.
https://neurobench.readthedocs.io
Apache License 2.0

Output encoding of regression model for primate_reaching dataset #224

Closed: JannKrausse closed this issue 1 month ago

JannKrausse commented 1 month ago

Hi *,

My colleague and I are currently participating in the IEEE BioCAS x NeuroBench Challenge for the Primate Reaching dataset. We are unsure how the network should compute its output. I also don't have much experience with regression tasks, so I hope my questions make sense.

We studied the example network in neurobench/examples/primate_reaching/SNN_3.py very carefully. As we understand it, for a given time step in a sequence, it takes the last self.bin_window_size input frames (where available) and predicts the cursor position (x, y) at that specific time step. It repeats this for every time step in the sequence, stacking the predictions into an output sequence.
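The per-time-step scheme described above can be sketched as a sliding window over the input. This is a minimal illustration, not the SNN_3 code: `sliding_window_predict` is a hypothetical helper, and `model` stands in for any callable that maps a window of input frames to an (x, y) prediction.

```python
import numpy as np

def sliding_window_predict(model, spikes, bin_window_size):
    """Hypothetical sketch of per-time-step prediction over one sequence.

    spikes: array of shape (T, n_channels), one row per time bin.
    model:  stand-in callable mapping a (window_len, n_channels) array
            to a length-2 array (x, y). Not the actual SNN_3 network.
    Returns an array of shape (T, 2), one prediction per time step.
    """
    T = spikes.shape[0]
    preds = []
    for t in range(T):
        # Take the last bin_window_size frames ending at t; near the
        # start of the sequence fewer frames are available.
        start = max(0, t - bin_window_size + 1)
        window = spikes[start:t + 1]
        preds.append(model(window))
    # Stack the per-step predictions into an output sequence.
    return np.stack(preds)
```

With a toy model such as `lambda w: w.sum(axis=0)[:2]` and `spikes = np.ones((5, 3))`, a window size of 3 yields predictions that grow over the first three steps and then stay constant once the window is full.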

What one could do instead is simply take a sequence and predict the cursor position at its last time step. This is technically what SNN_3 does, just once, without stacking predictions into an output sequence. Assuming the network has leaky units in its output layer, we think this should make little difference compared to the repository example when deployed on an infinite time series (i.e., in reality): any internal neuronal state from time steps further back than the training sequence length will have "leaked away". From what we saw, this approach speeds up training significantly.
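The "leaked away" argument can be made concrete for the simplest case, a linear leaky integrator readout y_t = beta * y_{t-1} + x_t. This is a simplification: real spiking layers add thresholds and resets, so the identity below only holds exactly for the linear case. The leak factor `beta = 0.9` and window length of 50 are assumed values for illustration. Running the integrator over a long stream and over only its last `window` inputs (starting from zero state) differ by exactly `beta**window` times the state at the window start.

```python
import numpy as np

def leaky_readout(inputs, beta, state=0.0):
    """Linear leaky integrator y_t = beta * y_{t-1} + x_t; returns the final y."""
    y = state
    for x in inputs:
        y = beta * y + x
    return y

rng = np.random.default_rng(0)
stream = rng.random(1000)  # stand-in for a long (effectively infinite) input stream
beta = 0.9                 # assumed leak factor
window = 50                # assumed training sequence length

# Deployment-style: integrate over the whole stream.
full = leaky_readout(stream, beta)
# Last-step-only training style: fresh state, only the final window.
last_only = leaky_readout(stream[-window:], beta)
# State accumulated before the window starts.
state_before = leaky_readout(stream[:-window], beta)

# The gap equals beta**window * state_before, which is small once beta**window is small.
print(full - last_only, beta**window * state_before)
```

How small the gap is depends entirely on `beta**window`: with slower leaks (beta close to 1) or shorter windows, old state does not fully decay within the training sequence length.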

Still, the key difference is that one network outputs a sequence of predictions and the other does not. The calculation of R2 also differs between the two cases: for SNN_3 it is computed over an actual sequence of predictions, whereas for the proposed approach it would be computed over unrelated predictions from a randomly composed batch of input sequences. Does the value of R2 still make sense in that case?
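One way to reason about the R2 question: once predictions and targets are pooled, R2 = 1 - SS_res / SS_tot does not depend on their order, so shuffling pairs together leaves it unchanged. What does change between the two setups is which samples enter the pool, and in particular the target variance in the denominator. The sketch below computes R2 per output dimension and averages over x and y; that averaging is an assumption for illustration, not necessarily how the NeuroBench metric is defined, and the data are synthetic.

```python
import numpy as np

def r2_score(y_true, y_pred):
    """R2 per output dimension (x, y), then averaged (averaging is an assumption).

    y_true, y_pred: arrays of shape (N, 2).
    """
    ss_res = ((y_true - y_pred) ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)
    return float(np.mean(1.0 - ss_res / ss_tot))

rng = np.random.default_rng(1)
y_true = rng.normal(size=(200, 2))                 # synthetic targets
y_pred = y_true + 0.3 * rng.normal(size=(200, 2))  # synthetic noisy predictions

# Permuting prediction/target pairs jointly does not change R2.
perm = rng.permutation(len(y_true))
print(r2_score(y_true, y_pred), r2_score(y_true[perm], y_pred[perm]))
```

Because SS_tot is measured around the mean of whatever samples are pooled, R2 over a single reach and R2 over a random batch of reaches can differ even for the same per-sample errors.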

Our problem is that the challenge description does not state whether the network's prediction should be a sequence or not. Are we missing a key idea or consideration in SNN_3 that makes the second approach nonsensical?

Thanks a lot for any help! Jann

jasonlyik commented 1 month ago

Hi Jann,

The inference task is seq2seq, which is to generate a full sequence of X,Y velocities given the sequence of spike data.

I don't fully understand your suggestion, but it sounds similar to how the SNN2 baseline was trained. Training notes are written in our preprint on page 23: https://arxiv.org/pdf/2304.04640. To summarize, I believe SNN2 trains on batches of sequences of length 50 (200ms), generating an output sequence of length 50. The MSE loss is linearly weighted across these points, so that later predictions are weighted more heavily than earlier ones. We found that this training scheme translates well to the infinite-series case during inference.
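A linearly weighted MSE of the kind described above might look like the sketch below. It is an illustration under assumptions, not the SNN2 training code: the helper names `make_windows` and `linearly_weighted_mse` are hypothetical, the weights are assumed to rise as t/T, and normalizing by the weight sum is a choice made here. The exact scheme is described in the preprint.

```python
import numpy as np

def make_windows(seq, length=50, stride=50):
    """Cut a long (T_total, C) recording into (n, length, C) training windows."""
    starts = range(0, seq.shape[0] - length + 1, stride)
    return np.stack([seq[s:s + length] for s in starts])

def linearly_weighted_mse(pred, target):
    """MSE over (batch, T, 2) sequences, weighting later steps more heavily.

    Weights rise linearly from 1/T at the first step to 1 at the last step
    (assumed form); the result is normalized by the sum of the weights.
    """
    T = pred.shape[1]
    weights = np.arange(1, T + 1) / T                    # shape (T,)
    per_step = ((pred - target) ** 2).mean(axis=(0, 2))  # MSE at each step, shape (T,)
    return float((weights * per_step).sum() / weights.sum())

# The same unit error costs 50x more at the last step than at the first.
pred = np.zeros((1, 50, 2))
late, early = np.zeros_like(pred), np.zeros_like(pred)
late[0, -1, :] = 1.0
early[0, 0, :] = 1.0
print(linearly_weighted_mse(pred, late), linearly_weighted_mse(pred, early))
```

The weighting de-emphasizes early steps, where the network has had little input history, which is one plausible reason such training transfers to long-running inference.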

So, the training approach does not necessarily need to match the inference setting, and we would consider training technique as an important axis for "best" solutions.

Let me know if this helps, I can also send you training code for SNN2, though it is very rough since it was developed at the same time as the overall harness.

JannKrausse commented 1 month ago

Hi Jason,

thanks for clarifying, that helped a lot!