SpikeInterface / spikeinterface

A Python-based module for creating flexible and robust spike sorting pipelines.
https://spikeinterface.readthedocs.io
MIT License

Whitening fix - compute covariance matrix in float #3505

Closed JoeZiminski closed 2 weeks ago

JoeZiminski commented 4 weeks ago

In whiten.py there is a datatype argument that sets the datatype of the output recording. This cast is applied after whitening and does not affect the datatype used to compute the whitening matrix. Up to and including 0.100.8, the covariance matrix was always computed in float32, but in 0.101.0 the line random_data = np.astype(random_data, np.float32), which sat just below the linked line, was removed. As a result, if the recording is int16, the covariance matrix is estimated in int16 and will overflow in nearly all cases.
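A minimal NumPy sketch of the failure mode, using synthetic int16 data (the shape and amplitudes are made up for illustration and are not taken from spikeinterface). A matrix product of two int16 arrays keeps the int16 dtype, so each sum of products silently wraps; casting to float32 first, as the removed line did, gives the intended values. The sketch uses the ndarray .astype method.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic int16 "raw" traces: 10,000 samples x 4 channels (illustrative values only).
random_data = rng.integers(-2000, 2000, size=(10_000, 4), dtype=np.int16)

# Without a cast, the product keeps the int16 dtype, so every
# sum of products silently wraps around the int16 range.
prod_int16 = random_data.T @ random_data

# Casting first gives the intended values.
data_f32 = random_data.astype(np.float32)
prod_f32 = data_f32.T @ data_f32

print(prod_int16.dtype, prod_f32.dtype)                      # int16 float32
print(prod_f32.diagonal().min() > np.iinfo(np.int16).max)    # True
```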

This PR reinstates that line, for now casting to float64. The data are cast back to the user-specified dtype (or the input recording's dtype) after whitening, so I don't see any downside to using maximum precision. The only consideration is that we may want to add GPU support in future, and using float32 would keep things consistent with that.
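To illustrate that the estimation dtype and the output dtype are independent, here is a hedged sketch of ZCA whitening in NumPy. The helpers compute_whitening_matrix_sketch and whiten_sketch, the eps regulariser and the toy mixing matrix are all hypothetical and do not reproduce spikeinterface's implementation; the dtype parameter lets you try float32 or float64 for the estimation step.

```python
import numpy as np


def compute_whitening_matrix_sketch(random_data, eps=1e-8, dtype=np.float32):
    """Hypothetical helper: ZCA whitening matrix from (samples, channels) data.

    The cast to `dtype` happens before any products, so int16 input cannot overflow.
    """
    data = random_data.astype(dtype)
    data = data - data.mean(axis=0)               # demean each channel
    cov = data.T @ data / data.shape[0]           # channel covariance
    eigvals, eigvecs = np.linalg.eigh(cov)
    # W = C^(-1/2); eps guards against near-zero eigenvalues.
    W = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + eps)) @ eigvecs.T
    return W.astype(dtype)


def whiten_sketch(traces, W, out_dtype=np.float32):
    """Hypothetical helper: apply W in W's dtype, then cast to the requested output dtype."""
    x = traces.astype(W.dtype)
    x = x - x.mean(axis=0)
    return (x @ W).astype(out_dtype)


rng = np.random.default_rng(1)
# Toy mixing matrix that correlates neighbouring channels.
mixing = np.array([[1.0, 0.5, 0.2, 0.0],
                   [0.0, 1.0, 0.5, 0.2],
                   [0.0, 0.0, 1.0, 0.5],
                   [0.0, 0.0, 0.0, 1.0]])
traces = (rng.standard_normal((20_000, 4)) @ mixing * 500).astype(np.int16)

W = compute_whitening_matrix_sketch(traces)
whitened = whiten_sketch(traces, W)
print(whitened.dtype)  # float32
```

The covariance of the whitened output should be close to the identity, regardless of the int16 input dtype.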

Unless I'm missing something, I think it would be worth releasing this fix in an update ASAP and noting it in the release notes header. My understanding is that most recordings are stored as int16, so this regression may have caught out quite a lot of users since 0.101.0, and it is hard to track down.


zm711 commented 4 weeks ago

Out of curiosity, since kilosort 1-3 use int16 (and even kilosort4 prefers it), how do they deal with whitening? I haven't looked.

JoeZiminski commented 4 weeks ago

I was curious about this too. Here is what they do:

kilosort4

In kilosort4, the processing is done in float32, as mentioned here. The covariance matrix is computed in float32 (X is in float32 and torch.zeros initialises in float32 by default). You can see the rescaling by 200 and int16 conversion for example here.
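A NumPy analogue of the pattern described above: each batch is cast to float32 and its product is accumulated into a float32 zeros buffer. This is a hedged sketch, not Kilosort's torch code; the helper batched_covariance and the batch size are assumptions, and demeaning and filtering are omitted.

```python
import numpy as np


def batched_covariance(traces, batch_size=1_000):
    """Hypothetical helper: average X^T X / n over batches, accumulated in float32."""
    n_channels = traces.shape[1]
    # np.zeros defaults to float64, so float32 is requested explicitly here
    # (torch.zeros, by contrast, defaults to float32).
    cc = np.zeros((n_channels, n_channels), dtype=np.float32)
    n_batches = 0
    for start in range(0, traces.shape[0], batch_size):
        x = traces[start:start + batch_size].astype(np.float32)  # cast each batch first
        cc += x.T @ x / x.shape[0]
        n_batches += 1
    return cc / n_batches


rng = np.random.default_rng(2)
traces = rng.integers(-1000, 1000, size=(10_000, 3), dtype=np.int16)
cc = batched_covariance(traces)
print(cc.dtype)  # float32
```

With equal-sized batches, the mean of the per-batch averages equals the overall average; a shorter final batch would be over-weighted by this simple scheme.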

Curiously, I can see in the data loading class here that if the data is uint16 it is cast to float, but the class __init__ states that the default expected datatype is int16. I can't find at the moment where the cast from int16 to float32 is performed (it is definitely performed for uint16). I need to look into this further; from the other comments and code sections, the data is definitely expected to be float32 for processing.

EDIT: Ah, it is done here; the main purpose of the uint16 block is centering the data around 0.
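Why the cast has to come before the centering can be shown with a small sketch. It assumes the uint16 zero level sits at the midpoint 2**15 = 32768; that offset and the helper center_uint16 are assumptions for illustration, not taken from Kilosort's code.

```python
import numpy as np


def center_uint16(raw):
    """Hypothetical helper: convert unsigned 16-bit samples to float32 centered on 0.

    Assumes the zero level sits at the midpoint 2**15 = 32768 (an assumption).
    """
    # Cast first: subtracting 32768 while still in uint16 would wrap around.
    return raw.astype(np.float32) - 2**15


raw = np.array([0, 32768, 65535], dtype=np.uint16)
print(center_uint16(raw).tolist())            # [-32768.0, 0.0, 32767.0]
print((raw - np.uint16(2**15)).tolist())      # [32768, 0, 32767]  (wrapped in uint16)
```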

kilosort<4 (based on 2.5)

First, the whitening matrix is computed in float32 (from data processed by gpufilter.py, which performs CAR and filtering in float32). Then, in the main preprocessing loop, the data are converted to float32, filtered, common-average referenced and whitened. Finally, the data are cast back to int16 here, with the scaling by 200 incorporated into the whitening matrix.
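Folding the gain into the whitening matrix is equivalent to scaling every whitened sample, because (200 W) applied to X equals 200 times W applied to X. A hedged sketch of that final step follows; the helper whiten_to_int16 is hypothetical, and the rounding and clipping are added safety steps that Kilosort may not perform.

```python
import numpy as np


def whiten_to_int16(traces_f32, W, scale=200.0):
    """Hypothetical helper: whiten float32 data and store the result as int16.

    The gain is folded into the whitening matrix once instead of scaling every sample.
    """
    W_scaled = (scale * W).astype(np.float32)
    whitened = traces_f32 @ W_scaled          # unit-variance data, now scaled by `scale`
    # Round, then clip to the int16 range so extreme values saturate instead of wrapping
    # (clipping is an added safety step, not necessarily what Kilosort does).
    info = np.iinfo(np.int16)
    return np.clip(np.rint(whitened), info.min, info.max).astype(np.int16)


rng = np.random.default_rng(3)
x = rng.standard_normal((5_000, 3)).astype(np.float32)   # already-white toy data
W = np.eye(3, dtype=np.float32)                          # identity whitening matrix
out = whiten_to_int16(x, W)
print(out.dtype)  # int16
```

With unit-variance input, the int16 output has a standard deviation of roughly 200, which keeps enough resolution after the integer cast.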

zm711 commented 4 weeks ago

Cool, I guess they cast too. Did you find the commit/PR where it was changed? Was there a comment about dropping the cast? Some of the preprocessing dtype handling is a mystery to me, so maybe someone had a reason.

JoeZiminski commented 3 weeks ago

Thanks @zm711, good point. I was just looking at the version diff; I think it was here.

yger commented 3 weeks ago

Then indeed, this is a big mistake, and random_data should always be cast to float32 before estimating the whitening matrix.

alejoe91 commented 3 weeks ago

@JoeZiminski are you planning to extend tests here or is it good to merge?

zm711 commented 3 weeks ago

Could we open an issue to add the regression test later? We use MS5 (MountainSort5), which relies on the whitening machinery, so we would love to have this on main for everyone.

JoeZiminski commented 3 weeks ago

Hey @alejoe91 @zm711, sorry, I got waylaid with something but will add some tests today! Agree this will be good to merge ASAP.

zm711 commented 3 weeks ago

Sounds good to me. Thanks Joe!

JoeZiminski commented 3 weeks ago

Hey @alejoe91 @zm711, in the end adding tests was taking longer than expected, so I'll make a new PR for them. Feel free to merge! At the next meeting we can discuss where to include a warning (e.g. the 0.101.0 release notes, or the release notes for a new version).

In the end I reverted to np.float32 to a) match what was there before, b) match other implementations (kilosort) and c) match GPU computation if that is implemented. The potentially overflowing computation operates on int16 ephys data and is the dot product between two channels' time series, whose length is set by the random chunk (10k samples by default). In an extreme, unrealistic case where the raw data all sit around +32767, demeaning is not selected, and the user requests 1M samples instead of 10k, an entry would be about 32767^2 * 1e6 ≈ 1.07e15, while the float32 maximum is about 3.4e38, so I think we'll be fine!
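The worst-case arithmetic above can be checked directly. This only reproduces the back-of-the-envelope bound (every sample at the int16 maximum, no demeaning, 1M samples); it says nothing about float32 rounding precision.

```python
import numpy as np

max_int16 = np.iinfo(np.int16).max              # 32767
n_samples = 1_000_000                            # the extreme 1M-sample case
worst_case = float(max_int16) ** 2 * n_samples   # one entry of the dot product
float32_max = float(np.finfo(np.float32).max)

print(f"{worst_case:.3g}")        # 1.07e+15
print(f"{float32_max:.3g}")       # 3.4e+38
print(worst_case < float32_max)   # True
```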