Oufattole / meds-torch

MIT License

Update the Triplet autoregressive model to perform conditional sampling of values conditioned on the code #107

Open Oufattole opened 1 month ago

Oufattole commented 1 month ago

Change 1: Conditional Sampling

ESGPT iterates over causally decoded embeddings in the dependency graph (see here), performing a classification or regression task on a subset of codes at each stage; the final node in the dependency graph is TTE, where it samples from a time-to-event distribution (see here). It appears to support either a LogNormal mixture or an Exponential distribution.

Specifically, it performs the following three steps:

  1. Takes encoded states of shape [batch_size, seq_len, dep_graph_len, hidden_size], where dep_graph_len represents a structured ordering of event components
  2. Sequentially processes each level of the dependency graph (starting at index 1), where each level:
    • Handles specific measurement subsets (both categorical & numerical) defined in measurements_per_dep_graph_level
    • Generates classifications using Bernoulli/Categorical distributions and/or regression values using Normal distributions
  3. Uses the final dependency graph node's encoding to sample time-to-event from either a LogNormal mixture or Exponential distribution, completing the event generation
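The per-level decoding in step 2 can be sketched roughly as follows. This is a shape-level illustration only; the `heads` modules and their grouping by `measurements_per_dep_graph_level` are hypothetical placeholders, not the actual ESGPT API.

```python
import torch

def decode_dep_graph(encoded, heads):
    """Sketch of ESGPT-style per-level decoding.

    encoded: [batch_size, seq_len, dep_graph_len, hidden_size]
    heads:   one prediction module per dependency-graph level (levels >= 1);
             each would emit classification/regression parameters for the
             measurements assigned to that level.
    """
    outputs = []
    # Level 0 encodes the event context; decoding starts at level 1.
    for level in range(1, encoded.shape[2]):
        level_embedding = encoded[:, :, level, :]  # [B, S, H]
        outputs.append(heads[level - 1](level_embedding))
    return outputs
```

The final level's embedding (not shown) would then parameterize the time-to-event distribution in step 3.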

This approach has three main limitations for our scale (hundreds of thousands of patients):

  1. Predicting all events at once from a single embedding is very challenging → Proposed Solution: We predict only one measurement/triplet (code, time, value) at a time, making each prediction task simpler
  2. No teacher forcing within timesteps makes learning harder → Proposed Solution: We use teacher forcing at every step of the triplet generation (code→time→value), allowing the model to learn from ground truth during training
  3. Conditioning on embeddings rather than actual samples can be problematic when distribution variance is high → Proposed Solution: We explicitly condition on each sampled value (using the sampled code to select which time distribution to use, and concatenating the sampled time embedding before predicting values)

Proposed sampling strategy:

  1. Using the transformer decoder, generate an embedding conditioned on past data $x_{<j}$
  2. Sample a code $y$ from a multinomial distribution over all codes
  3. Sample a time delta $t_y$ from a zero-inflated exponential distribution parameterized specifically for code $y$
  4. Concatenate the $t_y$ embedding with the original embedding to predict $(\mu_y, \sigma_y, p_y)$ for each code:
    • Sample the numeric value $v_y$ from $\mathcal{N}(\mu_y, \sigma_y)$ if the Bernoulli sample with probability $p_y$ is $1$
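The four steps above can be sketched as a single sampling function. This is a minimal illustration under assumed interfaces: `code_logits_head`, `time_heads`, and `value_head` are hypothetical modules standing in for whatever parameter heads the real model exposes.

```python
import torch
from torch.distributions import Bernoulli, Categorical, Exponential, Normal

@torch.no_grad()
def sample_triplet(embedding, code_logits_head, time_heads, value_head):
    """Sketch of conditional triplet sampling (code -> time -> value).

    embedding: [batch_size, hidden_size] from the transformer decoder,
               conditioned on past data x_{<j}.
    """
    # 2. Sample code y from a multinomial distribution over all codes.
    code = Categorical(logits=code_logits_head(embedding)).sample()  # [B]

    # 3. Sample time delta t_y from a zero-inflated exponential whose
    #    parameters are selected by the sampled code.
    zero_prob, rate = time_heads(embedding, code)          # each [B]
    is_nonzero = Bernoulli(probs=1 - zero_prob).sample()   # zero inflation
    time_delta = is_nonzero * Exponential(rate).sample()

    # 4. Concatenate the t_y embedding with the original embedding, predict
    #    (mu_y, sigma_y, p_y), and sample a numeric value only when the
    #    Bernoulli "value observed" draw is 1.
    joint = torch.cat([embedding, time_delta.unsqueeze(-1)], dim=-1)
    mu, sigma, value_prob = value_head(joint, code)
    has_value = Bernoulli(probs=value_prob).sample()
    value = torch.where(has_value.bool(),
                        Normal(mu, sigma).sample(),
                        torch.zeros_like(mu))
    return code, time_delta, value
```

During training, teacher forcing would replace each sampled quantity with its ground-truth counterpart before conditioning the next head.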

We need to modify the triplet forecasting get_forecast_logits function here to perform precisely these steps.

Change 2: Fix time shift bug

Add a one-step shift to the forecasted tokens. Currently, the model predicts the token at the same time step instead of forecasting the next one.
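The fix is the standard causal-LM shift: score the prediction at position $i$ against the observed token at position $i+1$. A minimal sketch (this is not the actual get_forecast_logits signature):

```python
import torch

def shift_for_next_token(logits, targets):
    """Align forecasts with the next time step.

    logits:  [batch_size, seq_len, vocab_size] predictions at each position
    targets: [batch_size, seq_len] observed tokens
    """
    # Drop the last prediction (nothing follows it) and the first target
    # (nothing predicts it), so position i is scored against token i+1.
    return logits[:, :-1, :], targets[:, 1:]
```

Without this shift, the loss rewards copying the current token, which is exactly the bug described above.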


TODOs: