LorenFrankLab / spyglass

Neuroscience data analysis framework for reproducible research, built by the Loren Frank Lab at UCSF
https://lorenfranklab.github.io/spyglass/
MIT License
94 stars 43 forks source link

ValueError when using the non-local detector #1190

Closed MichaelCoulter closed 5 days ago

MichaelCoulter commented 5 days ago

I am trying to run the non-local decoder with the input below, and I get the error shown after it. Any help would be appreciated. Thank you.

Input:

from spyglass.decoding.v1.clusterless import ClusterlessDecodingV1, ClusterlessDecodingSelection

selection_key = {
    "waveform_features_group_name": "CH65_12_04_all_tet",
    "position_group_name": "CH65_12_04",
    "decoding_param_name": 'CH65_1204_nonlocal',
    "nwb_file_name": nwb_file_name,
    "encoding_interval": "CH65_12_04_01",
    "decoding_interval": "CH65_12_04_01",
    "estimate_decoding_params": False,
}

ClusterlessDecodingSelection.insert1(
    selection_key,
    skip_duplicates=True,
)

ClusterlessDecodingSelection & selection_key

ClusterlessDecodingV1.populate(selection_key)

Error:

19-Nov-24 12:53:55 Fitting initial conditions...
19-Nov-24 12:53:55 Fitting discrete state transition
19-Nov-24 12:53:55 Fitting continuous state transition...
19-Nov-24 12:53:56 Fitting clusterless spikes...
Encoding models: 100%
39/39 [00:22<00:00, 1.97electrode/s]
19-Nov-24 12:54:25 Computing posterior...
19-Nov-24 12:54:25 Computing log likelihood...
Local Likelihood: 100%
39/39 [2:47:01<00:00, 206.42s/electrode]
No Spike Likelihood: 100%
39/39 [00:00<00:00, 53.80cell/s]
Non-Local Likelihood: 100%
39/39 [38:28<00:00, 38.67s/electrode]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In [134], line 20
     13 ClusterlessDecodingSelection.insert1(
     14     selection_key,
     15     skip_duplicates=True,
     16 )
     18 ClusterlessDecodingSelection & selection_key
---> 20 ClusterlessDecodingV1.populate(selection_key)

File ~/spyglass/src/spyglass/utils/dj_mixin.py:589, in SpyglassMixin.populate(self, *restrictions, **kwargs)
    587 if use_transact:  # Pass single-process populate to super
    588     kwargs["processes"] = processes
--> 589     return super().populate(*restrictions, **kwargs)
    590 else:  # No transaction protection, use bare make
    591     for key in keys:

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/datajoint/autopopulate.py:248, in AutoPopulate.populate(self, suppress_errors, return_exception_objects, reserve_jobs, order, limit, max_calls, display_progress, processes, make_kwargs, *restrictions)
    242 if processes == 1:
    243     for key in (
    244         tqdm(keys, desc=self.__class__.__name__)
    245         if display_progress
    246         else keys
    247     ):
--> 248         status = self._populate1(key, jobs, **populate_kwargs)
    249         if status is True:
    250             success_list.append(1)

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/datajoint/autopopulate.py:315, in AutoPopulate._populate1(self, key, jobs, suppress_errors, return_exception_objects, make_kwargs)
    313 self.__class__._allow_insert = True
    314 try:
--> 315     make(dict(key), **(make_kwargs or {}))
    316 except (KeyboardInterrupt, SystemExit, Exception) as error:
    317     try:

File ~/spyglass/src/spyglass/decoding/v1/clusterless.py:243, in ClusterlessDecodingV1.make(self, key)
    238             logger.warning(
    239                 f"Interval {interval_start}:{interval_end} is empty"
    240             )
    241             continue
    242         results.append(
--> 243             classifier.predict(
    244                 position_time=interval_time,
    245                 position=position_info.loc[interval_start:interval_end][
    246                     position_variable_names
    247                 ].to_numpy(),
    248                 spike_times=spike_times,
    249                 spike_waveform_features=spike_waveform_features,
    250                 time=interval_time,
    251                 **predict_kwargs,
    252             )
    253         )
    254     results = xr.concat(results, dim="intervals")
    256 # Save discrete transition and initial conditions

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/models/base.py:1639, in ClusterlessDetector.predict(self, spike_times, spike_waveform_features, time, position, position_time, is_missing, discrete_transition_covariate_data, cache_likelihood, n_chunks)
   1632 if discrete_transition_covariate_data is not None:
   1633     self.discrete_state_transitions_ = predict_discrete_state_transitions(
   1634         self.discrete_transition_design_matrix_,
   1635         self.discrete_transition_coefficients_,
   1636         discrete_transition_covariate_data,
   1637     )
-> 1639 (
   1640     acausal_posterior,
   1641     acausal_state_probabilities,
   1642     marginal_log_likelihood,
   1643     _,
   1644     _,
   1645 ) = self._predict(
   1646     time=time,
   1647     log_likelihood_args=(
   1648         position_time,
   1649         position,
   1650         spike_times,
   1651         spike_waveform_features,
   1652     ),
   1653     is_missing=is_missing,
   1654     cache_likelihood=cache_likelihood,
   1655     n_chunks=n_chunks,
   1656 )
   1658 return self._convert_results_to_xarray(
   1659     time,
   1660     acausal_posterior,
   1661     acausal_state_probabilities,
   1662     marginal_log_likelihood,
   1663 )

ValueError: too many values to unpack (expected 5)
edeno commented 5 days ago

This was fixed in the latest version of `non_local_detector`. Please update to v0.6.8.