Summary
I'm interested in learning whether (and how) it is possible to fit a GLM-HMM (e.g. LogisticRegressionHMM or CategoricalRegressionHMM) to a jagged input list (i.e. a list whose elements are lists of different lengths), such that the E-step is run separately on each inner list, while the M-step is run once on the entire input.
Context
To motivate this question, suppose jagged_list is a list where each element is a session_list (i.e. jagged_list = [session_list_1, session_list_2, ... session_list_s]), and each session_list contains trial samples (i.e. session_list_k = [trial_1, trial_2, ... trial_t]). Because the number of trials varies across sessions, this is a jagged array. These data represent trials from a single subject across multiple sessions. Trials from previous or future sessions should not be used when inferring a session's state probabilities and transitions (E-step), but all trials should be used together when learning the weights (M-step).
Note that this was previously supported in the SSM library: when the data were structured this way, SSM ran the E-step on each session separately, followed by an M-step across all sessions.
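For concreteness, the structure above can be mocked up with NumPy. The session lengths, input dimension, and variable names below are hypothetical; the point is only that the sessions cannot be stacked into one rectangular array, and that concatenating them would erase the session boundaries the E-step needs to respect.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data for one subject: three sessions with different trial counts.
# Each trial has an input vector (e.g. stimulus features) and a binary choice.
trials_per_session = [120, 85, 200]
input_dim = 2

jagged_inputs = [rng.normal(size=(t, input_dim)) for t in trials_per_session]
jagged_choices = [rng.integers(0, 2, size=t) for t in trials_per_session]

# Sessions of different lengths cannot be stacked into a single
# (num_sessions, num_trials, input_dim) array.
try:
    np.stack(jagged_inputs)
    stackable = True
except ValueError:
    stackable = False

# Concatenation works, but the result no longer knows where sessions end.
# An HMM run on it would carry the latent state from the last trial of one
# session into the first trial of the next, so the boundaries are kept
# separately as offsets into the concatenated array.
all_inputs = np.concatenate(jagged_inputs)                # shape (405, 2)
session_bounds = np.cumsum([0] + trials_per_session)      # [0, 120, 205, 405]
first_session = all_inputs[session_bounds[0]:session_bounds[1]]
```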
Issues
I think there are two roadblocks that prevent this from being possible.
1. Dynamax does not appear to support jagged arrays, because of how it is implemented in JAX. It has a procedure for batched inputs, but it requires every batch (e.g. session) to have the same number of time steps (e.g. trials).
2. The current fit_em method runs the E-step and M-step for each batch (as opposed to an E-step for each batch followed by a single M-step across all batches). There are separate methods for the E- and M-steps, but I'm unsure how to properly combine the per-session E-step outputs (i.e. SuffStats) and pass them into a single M-step call across all sessions. Any advice here would be greatly appreciated.
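To make the desired split concrete, here is a plain NumPy sketch (it does not use Dynamax's e_step or m_step) for a Bernoulli GLM-HMM, i.e. a two-choice model in the spirit of LogisticRegressionHMM. The idea it illustrates is that statistics from independent per-session E-steps are additive: expected transition counts are summed across sessions, initial-state posteriors are averaged, and per-trial posteriors are concatenated and used as weights in one weighted logistic regression. The function names (e_step_one_session, m_step, fit_em_jagged) are hypothetical, and the Newton-based M-step with a small L2 penalty is a choice made for this sketch, not how Dynamax or SSM implement it.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def e_step_one_session(pi, A, W, X, y):
    """Scaled forward-backward on ONE session of a Bernoulli GLM-HMM.

    pi: (K,) initial probs, A: (K, K) transitions, W: (K, D) per-state weights,
    X: (T, D) inputs, y: (T,) binary choices.
    Returns posteriors gamma (T, K), expected transition counts (K, K), log-lik.
    """
    T, K = X.shape[0], pi.shape[0]
    p1 = sigmoid(X @ W.T)                                  # P(y=1 | state), (T, K)
    lik = np.where(y[:, None] == 1, p1, 1.0 - p1)          # emission likelihoods
    alpha, c = np.zeros((T, K)), np.zeros(T)
    a = pi * lik[0]
    c[0], alpha[0] = a.sum(), a / a.sum()
    for t in range(1, T):                                  # forward pass
        a = (alpha[t - 1] @ A) * lik[t]
        c[t], alpha[t] = a.sum(), a / a.sum()
    beta = np.ones((T, K))
    for t in range(T - 2, -1, -1):                         # backward pass
        beta[t] = A @ (lik[t + 1] * beta[t + 1]) / c[t + 1]
    gamma = alpha * beta
    xi = A * (alpha[:-1].T @ (lik[1:] * beta[1:] / c[1:, None]))
    return gamma, xi, np.log(c).sum()

def m_step(stats, jagged_inputs, jagged_choices, W, n_newton=5, l2=1e-2):
    """One M-step on statistics pooled over ALL sessions."""
    gammas, xis, _ = zip(*stats)
    pi = np.mean([g[0] for g in gammas], axis=0)           # shared initial distribution
    xi_total = np.sum(xis, axis=0)                         # summed transition counts
    A = xi_total / xi_total.sum(axis=1, keepdims=True)
    X, y = np.concatenate(jagged_inputs), np.concatenate(jagged_choices)
    G = np.concatenate(gammas)                             # (total_trials, K)
    W = W.copy()
    for k in range(W.shape[0]):                            # weighted logistic regression per state
        r = G[:, k]
        for _ in range(n_newton):                          # Newton steps, L2-penalized
            p = sigmoid(X @ W[k])
            grad = X.T @ (r * (y - p)) - l2 * W[k]
            H = (X * (r * p * (1 - p))[:, None]).T @ X + l2 * np.eye(X.shape[1])
            W[k] = W[k] + np.linalg.solve(H, grad)
    return pi, A, W

def fit_em_jagged(pi, A, W, jagged_inputs, jagged_choices, num_iters=20):
    lls = []
    for _ in range(num_iters):
        # E-step: each session is processed on its own, so no state is
        # carried across session boundaries.
        stats = [e_step_one_session(pi, A, W, X, y)
                 for X, y in zip(jagged_inputs, jagged_choices)]
        lls.append(sum(s[2] for s in stats))
        # M-step: a single update using the statistics of every session.
        pi, A, W = m_step(stats, jagged_inputs, jagged_choices, W)
    return pi, A, W, np.array(lls)

# Usage on hypothetical data: 3 sessions, 1 input feature plus a bias column.
rng = np.random.default_rng(0)
K, D = 2, 2
jagged_inputs = [np.column_stack([rng.normal(size=T), np.ones(T)]) for T in (120, 85, 200)]
jagged_choices = [rng.integers(0, 2, size=T) for T in (120, 85, 200)]
pi0, A0 = np.full(K, 1.0 / K), np.array([[0.9, 0.1], [0.1, 0.9]])
W0 = rng.normal(scale=0.1, size=(K, D))
pi, A, W, lls = fit_em_jagged(pi0, A0, W0, jagged_inputs, jagged_choices)
```

The only coupling between sessions is inside m_step; the per-session E-step loop is embarrassingly parallel, which is what makes the "E-step per session, M-step per subject" split possible.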
Questions
1. Is this summary of the issues, and of the comparison with SSM, accurate?
2. Is there a way to implement the desired behavior of session-level E-steps and a subject-level M-step (e.g. via padding, or by calling the e_step and m_step methods directly)?
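On the padding option: a common trick (my assumption, not something I have confirmed Dynamax does internally) is to pad every session at the end to the longest length and give the padded trials an emission log-likelihood of 0, i.e. an uninformative observation. Because each row of the transition matrix sums to one, such trailing padding leaves each session's marginal log-likelihood (and the posteriors over its real trials) unchanged, so the sessions become an equal-length batch. The NumPy sketch below checks the log-likelihood part; forward_loglik and pad_sessions are hypothetical helpers.

```python
import numpy as np

def forward_loglik(log_pi, A, log_lik):
    """Scaled forward pass; log_lik[t, k] is the emission log-likelihood of trial t under state k."""
    alpha, total = np.exp(log_pi), 0.0
    for t in range(log_lik.shape[0]):
        if t > 0:
            alpha = alpha @ A                       # propagate through the transition matrix
        alpha = alpha * np.exp(log_lik[t])          # weight by the emission likelihood
        c = alpha.sum()
        total += np.log(c)
        alpha = alpha / c
    return total

def pad_sessions(jagged_log_lik):
    """Hypothetical helper: pad (T_s, K) arrays at the END to (S, T_max, K) with zeros.

    A zero log-likelihood means the padded trial carries no information.
    Also returns a (S, T_max) boolean mask marking real trials.
    """
    T_max = max(ll.shape[0] for ll in jagged_log_lik)
    K = jagged_log_lik[0].shape[1]
    padded = np.zeros((len(jagged_log_lik), T_max, K))
    mask = np.zeros((len(jagged_log_lik), T_max), dtype=bool)
    for s, ll in enumerate(jagged_log_lik):
        padded[s, :ll.shape[0]] = ll
        mask[s, :ll.shape[0]] = True
    return padded, mask

rng = np.random.default_rng(1)
K = 3
log_pi = np.log(np.full(K, 1.0 / K))
A = rng.dirichlet(np.ones(K), size=K)               # each row sums to 1
# Hypothetical emission log-likelihoods for three sessions of 7, 4 and 10 trials.
jagged_log_lik = [np.log(rng.uniform(0.05, 1.0, size=(T, K))) for T in (7, 4, 10)]

padded, mask = pad_sessions(jagged_log_lik)
unpadded_ll = np.array([forward_loglik(log_pi, A, ll) for ll in jagged_log_lik])
padded_ll = np.array([forward_loglik(log_pi, A, ll) for ll in padded])
# The two agree: at a padded step the normalizer is sum(alpha @ A) = 1.
# For the M-step, expected transition counts for the pair (t, t + 1) should
# only be accumulated where mask[s, t + 1] is True.
```

The padding has to go at the end of each session: padding at the start would push the first real trial away from the initial state distribution, since its prior would become pi multiplied by powers of A rather than pi itself.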
Thank you!