KordingLab / Neural_Decoding

A python package that includes many methods for decoding neural activity
BSD 3-Clause "New" or "Revised" License

Correct formatting input data for SimpleRNNClassification #11

Closed: horsto closed this issue 4 years ago

horsto commented 4 years ago

I have a quick question about the correct input format for decoders.SimpleRNNClassification (and similar classes). The docstring says:

X_train: numpy 3d array of shape [n_samples,n_time_bins,n_neurons]
    This is the neural data.
    See example file for an example of how to format the neural data correctly

But in your usage example notebook (central_concepts_in_ML_for_decoding.ipynb) the input X_train is formatted as [n_samples,n_neurons,n_time_bins]. Which one is correct?

jglaser2 commented 4 years ago

The docstring is correct.

Thanks for bringing this up. This is an error in that example notebook, and we will need to fix it. In the meantime, you can use np.swapaxes in the notebook to put the data into the correct format for the SimpleRNNClassification model. This will give you better accuracy than the current notebook, which has the error.
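Here is a minimal sketch of the np.swapaxes fix. It assumes the notebook's array has the layout [n_samples,n_neurons,n_time_bins]. The sizes and random values below are placeholders for illustration only and are not taken from the notebook.

```python
import numpy as np

# Placeholder dimensions, chosen only for illustration
n_samples, n_neurons, n_time_bins = 100, 20, 5

# Stand-in for the notebook's array, laid out as [n_samples, n_neurons, n_time_bins]
rng = np.random.default_rng(0)
X_notebook = rng.random((n_samples, n_neurons, n_time_bins))

# Swap axes 1 and 2 to get the docstring's layout: [n_samples, n_time_bins, n_neurons]
X_train = np.swapaxes(X_notebook, 1, 2)

print(X_train.shape)  # (100, 5, 20)

# Values are only re-indexed, not changed: sample i, neuron j, time bin t
# moves from X_notebook[i, j, t] to X_train[i, t, j]
assert X_train[3, 2, 7] == X_notebook[3, 7, 2]
```

np.swapaxes returns a view of the original data rather than a copy. If a contiguous array is needed, wrap the result in np.ascontiguousarray.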

horsto commented 4 years ago

Thanks, great.