Eden-Kramer-Lab / replay_trajectory_classification

State space models for decoding hippocampal trajectories and determining their type using sorted or clusterless data
MIT License

Translating API from 0.x to 1.x #30

Closed: yuvalwas closed this issue 7 months ago

yuvalwas commented 8 months ago

Hello, I am somewhat lost trying to translate between the two interfaces. My data is from a linear track and I want to have the classifier represent the two directions separately to account for remapping between directions. If I understand correctly, in 0.x I would do something like this:

clusterless_algorithm = 'multiunit_likelihood'
clusterless_algorithm_params = {
    'model_kwargs': {
        'bandwidth': np.array([mark_std, mark_std, mark_std, mark_std, pos_std])
    }
}

continuous_transition_types = [
    ['random_walk', 'uniform', 'uniform',     'uniform'],
    ['uniform',     'uniform', 'uniform',     'uniform'],
    ['uniform',     'uniform', 'random_walk', 'uniform'],
    ['uniform',     'uniform', 'uniform',     'uniform'],
]

encoding_group_to_state = [-1, -1, -1, 1, 1, 1] # corresponds to direction

classifier = ClusterlessClassifier(movement_var=movement_var,
                                   place_bin_size=pos_bin_size,
                                   clusterless_algorithm=clusterless_algorithm,
                                   clusterless_algorithm_params=clusterless_algorithm_params,
                                   continuous_transition_types=continuous_transition_types,
                                   discrete_transition_diag=state_diag_val)

classifier.fit(position,
               multiunits,
               encoding_group_labels=direction,
               encoding_group_to_state=encoding_group_to_state,
               is_training=is_training)

My translation attempt to 1.x is:

clusterless_algorithm = 'multiunit_likelihood'
clusterless_algorithm_params = {
    'mark_std': mark_std,
    'position_std': pos_std
}

continuous_transition_types = [
    [RandomWalk(movement_var=movement_var), Uniform(),   Uniform(),                             Uniform()],
    [Uniform(),                             Uniform(),   Uniform(),                             Uniform()],
    [Uniform(),                             Uniform(),   RandomWalk(movement_var=movement_var), Uniform()],
    [Uniform(),                             Uniform(),   Uniform(),                             Uniform()],
    ]

encoding_group_to_state = [-1, -1, 1, 1] # corresponds to direction

classifier = ClusterlessClassifier(environment=Environment(place_bin_size=pos_bin_size),
                                   continuous_transition_types=continuous_transition_types,
                                   clusterless_algorithm=clusterless_algorithm,
                                   clusterless_algorithm_params=clusterless_algorithm_params,
                                   discrete_transition_diag=DiagonalDiscrete(state_diag_val))

classifier.fit(position,
               multiunits,
               encoding_group_labels=direction,
               encoding_group_to_state=encoding_group_to_state,
               is_training=is_training)

Problems:

  1. encoding_group_to_state is no longer a keyword of ClusterlessClassifier.fit and I don't know what to do with it.
  2. I don't understand from the docs the distinction between different encoding groups and different environments, which of them I should use, and how.

Thank you

edeno commented 8 months ago

Hi @yuvalwas,

I believe this notebook should help: https://github.com/Eden-Kramer-Lab/replay_trajectory_classification/blob/master/notebooks/work_in_progress/Test_Inbound_Outbound.ipynb

It is basically the same model, but it uses ObservationModel to define the groups and encoding_group_labels in place of encoding_group_to_state.
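
To make the bookkeeping concrete, here is a minimal sketch of how a 0.x encoding_group_to_state list could map onto one observation model per state. It does not import the library: ObservationModelSketch is a hypothetical stand-in, its field names (environment_name, encoding_group) are assumptions to check against the linked notebook, and the direction labels are made-up toy values.

```python
from dataclasses import dataclass


# Hypothetical stand-in for the library's ObservationModel so that this
# sketch runs without replay_trajectory_classification installed.
# The field names are assumptions, not confirmed API.
@dataclass(frozen=True)
class ObservationModelSketch:
    environment_name: str = ""
    encoding_group: int = 0


# 0.x style: one entry per state, naming the encoding group it uses.
encoding_group_to_state = [-1, -1, 1, 1]

# 1.x style (as assumed here): one observation model per state,
# each carrying its own encoding group.
observation_models = [
    ObservationModelSketch(encoding_group=group)
    for group in encoding_group_to_state
]

# Toy per-time-bin labels; in the issue this is the `direction` array
# passed as encoding_group_labels.
direction = [-1, -1, 1, 1, 1, -1]


def training_rows(labels, group):
    """Return indices of time bins whose label matches the given group."""
    return [i for i, label in enumerate(labels) if label == group]


# Which time bins each state's encoding model would be trained on
# under this assumed bookkeeping.
rows_per_state = [
    training_rows(direction, model.encoding_group)
    for model in observation_models
]

for state, (model, rows) in enumerate(zip(observation_models, rows_per_state)):
    print(f"state {state}: group {model.encoding_group}, time bins {rows}")
```

If the library follows this pattern, the per-state list would presumably be given to the classifier when it is constructed, and fit would keep only encoding_group_labels=direction, with encoding_group_to_state dropped. The notebook linked above is the place to confirm the exact argument names.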