LorenFrankLab / spyglass

Neuroscience data analysis framework for reproducible research built by Loren Frank Lab at UCSF
https://lorenfranklab.github.io/spyglass/
MIT License

Model inclusion of track graph in Decoding notebooks #1157

Open CBroz1 opened 1 month ago

CBroz1 commented 1 month ago

This line of the clusterless decoding notebook mentions the possibility of including a track graph in decoding:

https://github.com/LorenFrankLab/spyglass/blob/03e39960e31bb358ba3f112220a5fdbb2e6acc76/notebooks/py_scripts/41_Decoding_Clusterless.py#L259-L261

In my brief attempt, I did not find the process of including this graph straightforward. It required remapping kwargs between a fetched table entry and classes from `non_local_detector`. I hit an error because the embedded `ObservationModel` had an empty environment name:

Partial error stack

```python
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/spyglass/utils/dj_mixin.py:612: in populate
    return super().populate(*restrictions, **kwargs)
../datajoint-python/datajoint/autopopulate.py:254: in populate
    status = self._populate1(key, jobs, **populate_kwargs)
../datajoint-python/datajoint/autopopulate.py:322: in _populate1
    make(dict(key), **(make_kwargs or {}))
src/spyglass/decoding/v1/clusterless.py:212: in make
    classifier.fit(
../../miniconda3/envs/spy/lib/python3.9/site-packages/non_local_detector/models/base.py:960: in fit
    self._fit(
../../miniconda3/envs/spy/lib/python3.9/site-packages/non_local_detector/models/base.py:493: in _fit
    self.initialize_state_index()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = ClusterlessDetector(clusterless_algorithm='clusterless_kde', clusterless_algorithm_params={'block_...ke=False)], sampling_frequency=500.0, state_names=['Continuous', 'Fragmented'])

    def initialize_state_index(self):
        self.n_discrete_states_ = len(self.state_names)
        bin_sizes = []
        state_ind = []
        is_track_interior = []
        for ind, obs in enumerate(self.observation_models):
            if obs.is_local or obs.is_no_spike:
                bin_sizes.append(1)
                state_ind.append(ind * np.ones((1,), dtype=int))
                is_track_interior.append(np.ones((1,), dtype=bool))
            else:
                environment = self.environments[
>                   self.environments.index(obs.environment_name)
                ]
E               ValueError: '' is not in list

../../miniconda3/envs/spy/lib/python3.9/site-packages/
```

edeno commented 1 month ago

I'm not quite sure what you're running here, but it seems the environment name given to the observation model is not one of the names in the list of `Environment`s passed to the model.
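The failing line in the traceback looks up each non-local observation model's environment by name (`self.environments.index(obs.environment_name)`). The sketch below reproduces that lookup with minimal, hypothetical stand-in classes (not the real `non_local_detector` ones) to show why a named environment paired with an observation model whose `environment_name` is empty raises `ValueError: '' is not in list`. It assumes the observation model's name defaults to an empty string, which is consistent with the `''` in the traceback, and it matches on a list of names. That should be equivalent to the original lookup, but the equivalence is an assumption.

```python
from dataclasses import dataclass


@dataclass
class Environment:
    # Hypothetical stand-in for non_local_detector's Environment; only the name matters here.
    environment_name: str = ""


@dataclass
class ObservationModel:
    # Hypothetical stand-in; the empty-string default is assumed from the '' in the traceback.
    environment_name: str = ""
    is_local: bool = False
    is_no_spike: bool = False


def find_environment(obs, environments):
    """Mimic the lookup in initialize_state_index: match an environment by name."""
    names = [env.environment_name for env in environments]
    return environments[names.index(obs.environment_name)]


# Environment named after the track graph, observation model left at its default name.
envs = [Environment(environment_name="my_track_graph")]  # hypothetical name
try:
    find_environment(ObservationModel(), envs)
except ValueError as err:
    print(err)  # '' is not in list
```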

CBroz1 commented 1 month ago
Run snippet

```python
import pytest


@pytest.fixture
def some_name(track_graph):
    from non_local_detector.environment import Environment
    from non_local_detector.models import ContFragClusterlessClassifier
    from spyglass.decoding.v1.core import DecodingParameters

    graph_entry = track_graph.fetch1()  # Restricted table

    class_kwargs = dict(
        clusterless_algorithm_params={
            "block_size": 10000,
            "position_std": 12.0,
            "waveform_std": 24.0,
        },
        environments=[
            Environment(
                # Must match the environment_name of the classifier's observation models
                environment_name=graph_entry["track_graph_name"],
                track_graph=track_graph.get_networkx_track_graph(),
                edge_order=graph_entry["linear_edge_order"],
                edge_spacing=graph_entry["linear_edge_spacing"],
            )
        ],
    )
    params_pk = {"decoding_param_name": "contfrag_clusterless"}
    DecodingParameters.insert_default()
    DecodingParameters.insert1(
        {
            **params_pk,
            "decoding_params": ContFragClusterlessClassifier(**class_kwargs),
            "decoding_kwargs": dict(),
        },
        skip_duplicates=True,
    )
```
edeno commented 1 month ago

If you omit `environment_name=graph_entry["track_graph_name"]`, it should work. Otherwise, you have to specify the same name in the `ObservationModel`.
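Before inserting parameters or calling `populate`, it can help to check that every non-local observation model names an environment that actually exists. The helper below, `missing_environment_names`, is hypothetical and not part of spyglass or `non_local_detector`. It is a minimal sketch that relies only on the attributes visible in the traceback (`is_local`, `is_no_spike`, `environment_name`) and on `Environment.environment_name` from the snippet. `SimpleNamespace` objects stand in for the real classes, and the empty default name is an assumption. The usage lines show the mismatch from the snippet and both resolutions suggested above.

```python
from types import SimpleNamespace


def missing_environment_names(observation_models, environments):
    """Return environment names used by non-local observation models that no
    environment in `environments` provides (hypothetical preflight helper)."""
    available = {env.environment_name for env in environments}
    missing = []
    for obs in observation_models:
        # Local and no-spike states are skipped, as in initialize_state_index.
        if obs.is_local or obs.is_no_spike:
            continue
        if obs.environment_name not in available:
            missing.append(obs.environment_name)
    return missing


def obs_model(environment_name=""):
    # Stand-in for a non-local ObservationModel; the "" default is assumed.
    return SimpleNamespace(
        environment_name=environment_name, is_local=False, is_no_spike=False
    )


named_env = [SimpleNamespace(environment_name="my_track_graph")]  # hypothetical name
default_env = [SimpleNamespace(environment_name="")]

# Mismatch from the snippet: named environment, default-named observation models.
print(missing_environment_names([obs_model(), obs_model()], named_env))  # ['', '']
# Fix 1: omit environment_name so both sides keep the default.
print(missing_environment_names([obs_model(), obs_model()], default_env))  # []
# Fix 2: give the observation models the same name as the environment.
print(missing_environment_names([obs_model("my_track_graph")] * 2, named_env))  # []
```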