h-mayorquin closed this 1 year ago.
@h-mayorquin, @alejoe91 added a fix in the PR (the same test failed while I was working on PR commit 957a169). I would try re-running the tests (or, at this point, wait for the cron job) to see whether his fix, which is now merged into main, makes the full tests pass.
Thanks, Ramon. It was actually introduced in #1981 and already fixed in #2041 :)
All right, then the tests should pass tomorrow when the cron job runs, and we can close this.
def fit_collision(
    collision,
    traces_with_margin,
    start_frame,
    end_frame,
    left,
    right,
    nbefore,
    all_templates,
    unit_inds_to_channel_indices,
    cut_out_before,
    cut_out_after,
):
    """
    Compute the best fit for a collision between a spike and its overlapping spikes.
    The function first cuts out the traces around the spike and its overlapping spikes, then
    fits a multi-linear regression model to the traces using the centered templates as predictors.

    Parameters
    ----------
    collision: np.ndarray
        A numpy array of shape (n_colliding_spikes, ) containing the colliding spikes (spike_dtype).
    traces_with_margin: np.ndarray
        A numpy array of shape (n_samples, n_channels) containing the traces with a margin.
    start_frame: int
        The start frame of the chunk for traces_with_margin.
    end_frame: int
        The end frame of the chunk for traces_with_margin.
    left: int
        The left margin of the chunk for traces_with_margin.
    right: int
        The right margin of the chunk for traces_with_margin.
    nbefore: int
        The number of samples before the spike to consider for the fit.
    all_templates: np.ndarray
        A numpy array of shape (n_units, n_samples, n_channels) containing the templates.
    unit_inds_to_channel_indices: dict
        A dictionary mapping unit indices to channel indices.
    cut_out_before: int
        The number of samples to cut out before the spike.
    cut_out_after: int
        The number of samples to cut out after the spike.

    Returns
    -------
    np.ndarray
        The fitted scaling factors for the colliding spikes.
    """
    from sklearn.linear_model import LinearRegression

    # make center of the spike externally
    sample_first_centered = np.min(collision["sample_index"]) - (start_frame - left)
    sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left)

    # construct sparsity as union between units' sparsity
    sparse_indices = np.array([], dtype="int")
    for spike in collision:
        sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]]
        sparse_indices = np.union1d(sparse_indices, sparse_indices_i)

    local_waveform_start = max(0, sample_first_centered - cut_out_before)
    local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after)
    local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices]

    y = local_waveform.T.flatten()
    X = np.zeros((len(y), len(collision)))
    for i, spike in enumerate(collision):
        full_template = np.zeros_like(local_waveform)
        # center wrt cutout traces
        sample_centered = spike["sample_index"] - (start_frame - left) - local_waveform_start
        template = all_templates[spike["unit_index"]][:, sparse_indices]
        template_cut = template[nbefore - cut_out_before : nbefore + cut_out_after]
        # deal with borders
        if sample_centered - cut_out_before < 0:
            full_template[: sample_centered + cut_out_after] = template_cut[cut_out_before - sample_centered :]
        elif sample_centered + cut_out_after > end_frame + right:
            full_template[sample_centered - cut_out_before :] = template_cut[: -cut_out_after - (end_frame + right)]
        else:
>           full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut
E           ValueError: could not broadcast input array from shape (210,4) into shape (157,4)
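The broadcast error above is consistent with how NumPy slicing works: a slice that runs past the end of an array is silently shortened. If `template_cut` holds `cut_out_before + cut_out_after` = 210 rows but only 157 rows of `full_template` remain after `sample_centered - cut_out_before`, the plain `else` assignment fails. In the excerpt, the `elif` guard compares a position inside `local_waveform` with `end_frame + right`, which looks like a chunk-level frame number rather than the length of `local_waveform`, so the right-border branch may not trigger when it should. Below is a minimal sketch of clipping both borders against the target array's own length. `paste_clipped` is a hypothetical helper, not a SpikeInterface function and not the fix merged in #2041, and the 60/150 split of the 210 samples is an assumed example.

```python
import numpy as np


def paste_clipped(full, cut, center, before, after):
    """Hypothetical helper (not part of SpikeInterface).

    Copy `cut`, which holds `before + after` rows with the peak on row `before`,
    into `full` so that the peak lands on row `center`. Rows that would fall
    outside `full` on either border are dropped instead of raising.
    """
    assert cut.shape[0] == before + after, "cut must hold before + after rows"
    n = full.shape[0]
    start, stop = center - before, center + after  # target window, may overhang
    lo, hi = max(0, start), min(n, stop)            # window clipped to `full`
    if hi <= lo:
        # the window lies entirely outside `full`: nothing to copy
        return full
    # row r of `cut` maps to row start + r of `full`, so shift the bounds back
    full[lo:hi] = cut[lo - start : hi - start]
    return full


if __name__ == "__main__":
    # Assumed split of the 210 template samples; the real cut_out values are not shown.
    before, after = 60, 150
    cut = np.ones((before + after, 4))
    center = 203  # window [143, 353) overhangs a 300-row target by 53 rows

    naive = np.zeros((300, 4))
    try:
        # the slice is silently shortened to 157 rows, so this assignment fails
        naive[center - before : center + after] = cut
    except ValueError as err:
        print("naive:", err)

    clipped = paste_clipped(np.zeros((300, 4)), cut, center, before, after)
    print("clipped rows filled:", int(clipped[:, 0].sum()))  # 157
```

Clipping against `full.shape[0]` keeps one code path for both borders and the interior, instead of three branches that each need their own bounds arithmetic.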
And also:
def main_function(cls, traces, d):
    templates = d["templates"]
    num_templates = d["num_templates"]
    num_channels = d["num_channels"]
    num_samples = d["num_samples"]
    overlaps = d["overlaps"]
    norms = d["norms"]
    nbefore = d["nbefore"]
    nafter = d["nafter"]
    omp_tol = np.finfo(np.float32).eps
    num_samples = d["nafter"] + d["nbefore"]
    neighbor_window = num_samples - 1
    min_amplitude, max_amplitude = d["amplitudes"]
    ignored_ids = d["ignored_ids"]
    stop_criteria = d["stop_criteria"][:, np.newaxis]
    vicinity = d["vicinity"]
    rank = d["rank"]

    num_timesteps = len(traces)

    num_peaks = num_timesteps - num_samples + 1
    conv_shape = (num_templates, num_peaks)
    scalar_products = np.zeros(conv_shape, dtype=np.float32)

    # Filter using overlap-and-add convolution
    if len(ignored_ids) > 0:
        mask = ~np.isin(np.arange(num_templates), ignored_ids)
>       spatially_filtered_data = np.matmul(d["spatial"][:, mask, :], traces.T[np.newaxis, :, :])
E       ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 32 is different from 4)
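The `matmul` error follows from NumPy's gufunc signature `(n?,k),(k,m?)->(n?,m?)`: the last axis of the first operand must equal the second-to-last axis of the second. Operand 1 here is `traces.T[np.newaxis, :, :]`, whose core dimension 0 is the channel count of the traces (32), while the last axis of `d["spatial"][:, mask, :]` is 4. This suggests the spatial components were built for 4 channels while the traces passed in have 32. The minimal reproduction below assumes a `(rank, num_templates, num_channels)` layout for the spatial components, which is consistent with the call above but not confirmed by it. `spatial_filter`, the rank/template/timestep sizes, and the ignored template ids are all hypothetical; only the 4 vs 32 channel counts come from the error message.

```python
import numpy as np


def spatial_filter(spatial, traces):
    """Hypothetical helper mirroring the matmul call in the traceback.

    spatial: (rank, n_templates, n_channels), an assumed layout
    traces:  (n_timesteps, n_channels)
    returns: (rank, n_templates, n_timesteps)
    """
    if spatial.shape[-1] != traces.shape[1]:
        # fail early with both channel counts instead of the generic gufunc message
        raise ValueError(
            f"spatial components have {spatial.shape[-1]} channels, "
            f"but traces have {traces.shape[1]}"
        )
    # (rank, T, C) @ (1, C, N): the leading 1 broadcasts against rank
    return np.matmul(spatial, traces.T[np.newaxis, :, :])


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical sizes; only 4 vs 32 channels comes from the error message.
    rank, num_templates, num_timesteps = 5, 10, 100
    traces = rng.standard_normal((num_timesteps, 32))
    mask = ~np.isin(np.arange(num_templates), [2, 7])  # hypothetical ignored_ids

    spatial_4ch = rng.standard_normal((rank, num_templates, 4))
    try:
        # same shapes as in the traceback: k = 4 for operand 0, 32 for operand 1
        np.matmul(spatial_4ch[:, mask, :], traces.T[np.newaxis, :, :])
    except ValueError as err:
        print("mismatch:", err)

    spatial_32ch = rng.standard_normal((rank, num_templates, 32))
    print(spatial_filter(spatial_32ch[:, mask, :], traces).shape)  # (5, 8, 100)
```

Checking the channel counts up front trades NumPy's generic gufunc message for one that names both sizes directly, which makes a template/recording mismatch easier to spot in a test log.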
@h-mayorquin, I think this is done now, right?
Yes, let's close this.
Not sure whether this is related to the latest changes, but that is the only connection I can make:
https://github.com/SpikeInterface/spikeinterface/pull/2041