probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Dynamax GLM-HMM using jagged input arrays with differing E and M steps #362

Open jess-breda opened 1 month ago

jess-breda commented 1 month ago

Summary

I'm interested in learning whether and how it would be possible to fit a GLM-HMM (e.g., LogisticRegressionHMM, CategoricalRegressionHMM) with a jagged input list (i.e., a list whose elements are lists of different lengths), such that the E-step is run separately for each inner list, while the M-step is run on the entire input.

Context

To motivate this question, suppose jagged_list is a list where each element is a session_list (i.e., jagged_list = [session_list_1, session_list_2, ... session_list_s]), and each session_list contains trial samples (i.e., session_list_k = [trial_1, trial_2, ... trial_t]). Because the number of trials varies across sessions, this is a jagged array. These data represent trials from a single subject across multiple sessions. Trials from previous or future sessions should not be used for learning a session's state probabilities and transitions (E-step), but all trials should be used together for learning the weights (M-step).
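Because a JAX array must be rectangular, one common workaround (a general technique, not something the issue or Dynamax prescribes) is to pad every session to the length of the longest one and keep a boolean mask marking the real trials. Below is a minimal NumPy sketch, assuming each session is a sequence of per-trial observations that all share the same per-trial shape; pad_sessions and the example data are hypothetical.

```python
import numpy as np

def pad_sessions(jagged_list, pad_value=0):
    """Stack variable-length sessions into a (num_sessions, max_trials, ...) array.

    Hypothetical helper, not part of Dynamax. Returns the padded array and a
    boolean mask that is True for real trials and False for padding.
    """
    sessions = [np.asarray(s) for s in jagged_list]
    num_sessions = len(sessions)
    max_trials = max(s.shape[0] for s in sessions)
    # Per-trial shape (e.g. () for choices, (D,) for inputs); assumed equal across sessions.
    trailing_shape = sessions[0].shape[1:]
    padded = np.full((num_sessions, max_trials) + trailing_shape, pad_value,
                     dtype=sessions[0].dtype)
    mask = np.zeros((num_sessions, max_trials), dtype=bool)
    for k, s in enumerate(sessions):
        padded[k, : s.shape[0]] = s   # copy the real trials
        mask[k, : s.shape[0]] = True  # mark them as real
    return padded, mask

# Hypothetical example: three sessions of binary choices with 5, 3 and 4 trials.
jagged_list = [[0, 1, 1, 0, 1], [1, 1, 0], [0, 0, 1, 1]]
padded, mask = pad_sessions(jagged_list)
print(padded.shape)      # (3, 5)
print(mask.sum(axis=1))  # [5 3 4]
```

The mask is what later lets an E-step ignore the padded trials and an M-step exclude them from the pooled statistics.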

Note that this was previously supported in the SSM library: with the data structured this way, SSM ran the E-step for each session, followed by a single M-step across all sessions.

Issues

I think there are two roadblocks that prevent this from being possible.

  1. Dynamax does not appear to support jagged arrays, due to its JAX implementation (JAX arrays must be rectangular).
  2. The current fit_em method runs the E-step and M-step for each batch (as opposed to an E-step for each batch and an M-step across all batches).
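Padding only helps if the padded trials leave the E-step unchanged. For an HMM with trailing padding this can be arranged by giving each padded trial an emission likelihood of 1 under every state. The forward message then passes through unchanged, and because each row of the transition matrix sums to 1, the marginal log-likelihood is unaffected; the backward messages from padded trials are all ones, so the posteriors at real trials are unchanged too. The sketch below checks the log-likelihood part on a toy two-state model with categorical emissions, used here as a stand-in for a GLM-HMM's per-trial likelihoods. forward_log_likelihood is a hypothetical helper, not Dynamax's implementation; a JAX version written with jax.lax.scan could then be jax.vmap-ed over the padded session axis.

```python
import numpy as np

def forward_log_likelihood(initial_probs, transition_matrix, emission_liks, mask=None):
    """Marginal log-likelihood of one session via the scaled forward algorithm.

    emission_liks[t, k] = p(y_t | z_t = k). Where mask[t] is False (trailing
    padding), the likelihood is replaced by 1 for every state, so that step
    contributes log(1) = 0 to the total.
    """
    emission_liks = np.asarray(emission_liks, dtype=float)
    if mask is not None:
        emission_liks = np.where(np.asarray(mask)[:, None], emission_liks, 1.0)
    alpha = initial_probs * emission_liks[0]
    norm = alpha.sum()
    log_lik = np.log(norm)
    alpha = alpha / norm
    for t in range(1, emission_liks.shape[0]):
        # Propagate through the transition matrix, then weight by the emission likelihood.
        alpha = (alpha @ transition_matrix) * emission_liks[t]
        norm = alpha.sum()
        log_lik += np.log(norm)
        alpha = alpha / norm
    return log_lik

# Toy 2-state, 2-choice model (hypothetical numbers).
pi = np.array([0.6, 0.4])
A = np.array([[0.9, 0.1], [0.2, 0.8]])
B = np.array([[0.8, 0.2], [0.3, 0.7]])  # B[k, y] = p(choice y | state k)

choices = np.array([0, 1, 1])
padded_choices = np.array([0, 1, 1, 0, 0])  # same session padded to length 5
mask = np.array([True, True, True, False, False])

ll = forward_log_likelihood(pi, A, B[:, choices].T)
ll_padded = forward_log_likelihood(pi, A, B[:, padded_choices].T, mask)
print(np.isclose(ll, ll_padded))  # True
```

This only holds for padding at the end of a session; posteriors computed at the padded trials themselves are meaningless and still have to be masked out of any M-step.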

Questions

  1. Is this summary of issues and comparison accurate?
  2. Is there a way to implement the desired behavior of session-level E-steps and subject-level M-step? (e.g., via padding or using e_step and m_step methods)
  3. Any additional thoughts or suggestions?
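To make the desired behaviour concrete, here is a minimal NumPy sketch of EM in which the E-step (forward-backward) runs separately on each padded session and the M-step pools the expected counts from every session into a single update. It assumes sessions are already padded to a common length with a boolean mask of real trials, and it uses a plain categorical-emission HMM as a stand-in: in a GLM-HMM the emission update would instead be a posterior-weighted logistic (or multinomial) regression over all real trials. e_step_session, m_step and fit_em here are hypothetical functions, not Dynamax's or SSM's API.

```python
import numpy as np

def e_step_session(pi, A, B, y, mask):
    """Scaled forward-backward for ONE padded session of categorical observations y.

    Trailing padded trials (mask False) get emission likelihood 1, so they do
    not change the posteriors at real trials. Returns gamma (T, K), summed
    expected transition counts (K, K), and the session log-likelihood.
    """
    T, K = len(y), len(pi)
    lik = np.where(mask[:, None], B[:, y].T, 1.0)
    alpha, c = np.zeros((T, K)), np.zeros(T)  # c holds per-step normalizers
    alpha[0] = pi * lik[0]
    c[0] = alpha[0].sum()
    alpha[0] /= c[0]
    for t in range(1, T):
        alpha[t] = (alpha[t - 1] @ A) * lik[t]
        c[t] = alpha[t].sum()
        alpha[t] /= c[t]
    beta = np.ones((T, K))
    for t in range(T - 2, -1, -1):
        beta[t] = A @ (lik[t + 1] * beta[t + 1]) / c[t + 1]
    gamma = alpha * beta
    xi_sum = np.zeros((K, K))
    for t in range(T - 1):
        if mask[t + 1]:  # count only transitions between two real trials
            xi_sum += alpha[t][:, None] * A * (lik[t + 1] * beta[t + 1])[None, :] / c[t + 1]
    return gamma, xi_sum, np.log(c).sum()

def m_step(gammas, xi_sums, ys, masks, num_classes):
    """One parameter update from expected counts pooled over ALL sessions."""
    K = gammas.shape[-1]
    pi = gammas[:, 0].sum(axis=0)  # every session restarts from the initial distribution
    A = xi_sums.sum(axis=0)
    B = np.zeros((K, num_classes))
    for v in range(num_classes):
        # Only real trials (mask True) contribute emission counts.
        B[:, v] = (gammas * ((ys == v) & masks)[..., None]).sum(axis=(0, 1))
    return pi / pi.sum(), A / A.sum(axis=1, keepdims=True), B / B.sum(axis=1, keepdims=True)

def fit_em(ys, masks, num_states, num_classes, num_iters=10, seed=0):
    rng = np.random.default_rng(seed)
    pi = np.full(num_states, 1.0 / num_states)
    A = 0.8 * np.eye(num_states) + 0.2 / num_states  # "sticky" initial transitions
    B = rng.dirichlet(np.ones(num_classes), size=num_states)
    lls = []
    for _ in range(num_iters):
        # E-step: forward-backward separately for each session.
        results = [e_step_session(pi, A, B, y, m) for y, m in zip(ys, masks)]
        gammas = np.stack([r[0] for r in results])
        xi_sums = np.stack([r[1] for r in results])
        lls.append(sum(r[2] for r in results))
        # M-step: a single update shared by all sessions.
        pi, A, B = m_step(gammas, xi_sums, ys, masks, num_classes)
    return (pi, A, B), lls

# Hypothetical padded data: 3 sessions with 5, 3 and 4 real trials.
ys = np.array([[0, 1, 1, 0, 1], [1, 1, 0, 0, 0], [0, 0, 1, 1, 0]])
masks = np.array([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [1, 1, 1, 1, 0]], dtype=bool)
(pi, A, B), lls = fit_em(ys, masks, num_states=2, num_classes=2)
```

Because each session's forward-backward pass starts from the initial-state distribution, no state information leaks between sessions; the sessions are coupled only through the shared parameters updated in the M-step.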

Thank you!

atlaie commented 4 weeks ago

+1! I'd also be very interested in having this implemented in Dynamax