Out of curiosity, since kilosort 1-3 (and even 4) prefer `int16`, how do they deal with their whitening? I haven't looked.
I was curious about this also, they go: `int16` raw data → `float32` for CAR, filtering, whitening → `int16`.

**kilosort4**
In kilosort4, the processing is done in `float32`, as mentioned here. The covariance matrix is computed in `float32` (`X` is in `float32` and `torch.zeros` initialises in `float32` by default). You can see the rescaling by 200 and `int16` conversion, for example, here.
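For illustration, a rough sketch of that dtype flow (assumed shapes and names, not kilosort4's actual code; the ZCA-style whitening and the 200/`int16` step mirror the linked sections):

```python
import torch

n_chans, n_samples = 16, 60_000
X = torch.randn(n_chans, n_samples)  # filtered traces; float32 by default

# covariance accumulated in float32 (torch.zeros initialises float32)
CC = torch.zeros(n_chans, n_chans)
CC += (X @ X.T) / n_samples

# ZCA-style whitening matrix from the eigendecomposition
e, v = torch.linalg.eigh(CC)
W = v @ torch.diag(1.0 / torch.sqrt(e + 1e-6)) @ v.T

# rescale by 200 and cast to int16, as in the linked code
X_proc = (200.0 * (W @ X)).to(torch.int16)
```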
Curiously, I can see in the data loading class here that if the data is `uint16` it is cast to float. But in the class `__init__` it states the default datatype expected is `int16`. I can't find ATM where this cast from `int16` to `float32` is performed (it is definitely performed for `uint16`). I need to look into this further; from the other comments and code sections, the data is definitely expected to be in `float32` for processing.

EDIT: Ah, it is done here; the main purpose of the `uint16` block is the centering around 0.
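A minimal sketch of what that `uint16` block is doing (my paraphrase, not the kilosort4 source): the cast to float and the shift by 2**15 recentre unsigned data around 0, matching the `int16` convention:

```python
import numpy as np

def to_centred_float32(data: np.ndarray) -> np.ndarray:
    # uint16 spans [0, 65535]; subtracting 2**15 centres it around 0
    if data.dtype == np.uint16:
        return data.astype(np.float32) - 2**15
    # int16 is already centred, so a plain cast suffices
    return data.astype(np.float32)
```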
**kilosort<4 (based on 2.5)**
First, the whitening matrix is computed in `float32` (from data processed by `gpufilter.py`, which performs CAR and filtering in `float32`). Then, in the main preprocessing loop, the `float32` conversion, filtering and CAR are performed, as well as whitening. Finally, the cast back to `int16` is performed here, with the 200 scaling incorporated into the whitening matrix.
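Schematically, and with assumed names (a paraphrase of the 2.5 flow, not the actual source), each batch goes through something like:

```python
import numpy as np

def whiten_batch(batch_int16: np.ndarray, Wrot_scaled: np.ndarray) -> np.ndarray:
    """Wrot_scaled = 200 * whitening matrix, so the scaling is baked in."""
    data = batch_int16.astype(np.float32)  # float32 for filtering/CAR/whitening
    # ... high-pass filtering and CAR happen here in float32 ...
    return (data @ Wrot_scaled).astype(np.int16)  # cast back for the proc file
```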
Cool, I guess they cast too. Did you find the commit/PR where it was changed? Was there a comment about dropping the cast? Some of the preprocessing dtype stuff is a mystery to me, so maybe someone had a reason?
Thanks @zm711, good point. I was just looking at the version diff; I think it was here.
Then indeed, this is a big mistake and `random_data` should always be cast to `float32` before estimating the whitening matrix.
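For concreteness, a minimal sketch of the fix (hypothetical function name; the SVD-based ZCA construction is illustrative, not the literal `whiten.py` code):

```python
import numpy as np

def compute_whitening_matrix_sketch(random_data: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    data = random_data.astype(np.float32)  # the reinstated cast: no int16 overflow
    data = data - data.mean(axis=0)        # demean per channel
    cov = data.T @ data / data.shape[0]    # covariance accumulated in float32
    u, s, _ = np.linalg.svd(cov)
    return u @ np.diag(1.0 / np.sqrt(s + eps)) @ u.T
```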
@JoeZiminski are you planning to extend tests here or is it good to merge?
Could we note the regression test to add in an issue? We use MS5, which relies on the whitening machinery. Would love to have this on main for everyone.
Hey @alejoe91 @zm711, sorry, I got waylaid with something but will add some tests today! Agree this will be good to merge ASAP.
Sounds good to me. Thanks Joe!
Hey @alejoe91 @zm711, in the end adding the tests was taking longer than expected, so I'll make a new PR for this. Feel free to merge! At the next meeting we can discuss where to include a warning (e.g. maybe the 0.101.0 release notes, or the release notes for a new version).
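Roughly, the kind of regression test I have in mind for the follow-up PR (a sketch only, leaning on `generate_recording`, `scale` and `whiten`; the gain and tolerances are placeholders):

```python
import numpy as np
from spikeinterface.core import generate_recording
from spikeinterface.preprocessing import scale, whiten

def test_whiten_int16_does_not_overflow():
    rec = generate_recording(num_channels=4, durations=[2.0])
    # same data stored as int16 (the common on-disk format) and as float32
    rec_int16 = scale(rec, gain=200.0, dtype="int16")
    rec_float32 = scale(rec, gain=200.0, dtype="float32")

    w_int16 = whiten(rec_int16, dtype="float32")
    w_float32 = whiten(rec_float32, dtype="float32")

    # if the covariance were estimated in int16 it would wrap around and
    # the two outputs would diverge wildly
    np.testing.assert_allclose(
        w_int16.get_traces(), w_float32.get_traces(), rtol=1e-2, atol=1e-1
    )
```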
In the end I reverted to `np.float32` to a) match what was there before, b) match the other implementation (kilosort), and c) match GPU support if this is implemented. The overflowing computation on `int16` ephys data is the dot product between two channel timeseries whose length is determined by the random chunk, by default 10k samples. In some impossible example, if somehow the raw data is all around +32767, demeaning is not selected, and the user specifies 1M samples instead of 10k, the value will be ~1.07e+15; the `float32` max value is ~3.4e+38, so I think we'll be fine!
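To make the arithmetic concrete (illustrative numbers only):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.integers(-500, 500, size=10_000).astype(np.int16)
y = rng.integers(-500, 500, size=10_000).astype(np.int16)

# int16 dot products accumulate in int16 and silently wrap around
print(np.dot(x, y))                                        # nonsense value
print(np.dot(x.astype(np.float32), y.astype(np.float32)))  # correct

# the "impossible" worst case: all samples at +32767, 1e6 samples
print(32767.0**2 * 1e6)          # ~1.07e+15
print(np.finfo(np.float32).max)  # ~3.40e+38, plenty of headroom
```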
In `whiten.py` there is a datatype argument to set the datatype of the output recording. The casting is performed after whitening has been performed and does not affect the datatype used for computing the whitening matrix. Up until `0.100.8`, the covariance matrix was always computed in `float32`, but in `0.101.0` the line `random_data = np.astype(random_data, np.float32)` that was under this line was removed. Now if the recording is `int16`, the covariance matrix estimation is performed in `int16` and will overflow in nearly all cases.

This PR reinstates that line, for now casting to `float64`. The recording `dtype` will be cast back to the user-specified `dtype` or the input recording `dtype`, so I don't think there is any downside in using maximum precision. The only thing is, maybe we will want to move to GPU support in future, and using `float32` would keep it consistent.

Unless I'm missing something, I think it would be worth rolling this out in an update ASAP and making a note of this in the release note header. My understanding is that the majority of recordings are stored as `int16`, and so this regression may catch out quite a lot of users since `0.101.0` and is quite hard to track down.

TODOs