Change 1: Conditional Sampling
ESGPT iterates over causally decoded embeddings in the dependency graph (see here), performing a classification or regression task on a subset of codes at each stage; the final node in the dependency graph is TTE, where it samples a time-to-event distribution (see here). It appears to support either a LogNormal mixture or an Exponential distribution.
Specifically, it performs the following three steps (a rough code sketch follows the list):
Takes encoded states of shape [batch_size, seq_len, dep_graph_len, hidden_size], where dep_graph_len represents a structured ordering of event components
Sequentially processes each level of the dependency graph (starting at index 1), where each level:
Handles specific measurement subsets (both categorical & numerical) defined in measurements_per_dep_graph_level
Generates classifications using Bernoulli/Categorical distributions and/or regression values using Normal distributions
Uses the final dependency graph node's encoding to sample time-to-event from either a LogNormal mixture or Exponential distribution, completing the event generation
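For concreteness, here is a rough sketch of that decoding loop. decode_event, heads, and tte_head are hypothetical names rather than ESGPT's actual API; only the shapes and the level-by-level structure come from the description above:

```python
def decode_event(encoded, measurements_per_dep_graph_level, heads, tte_head):
    """Hypothetical sketch of ESGPT-style per-event decoding.

    encoded: [batch_size, seq_len, dep_graph_len, hidden_size]
    heads: per-measurement classification/regression heads
    tte_head: maps the final node's encoding to TTE distribution parameters
    """
    preds = {}
    dep_graph_len = encoded.shape[2]
    # Intermediate levels (starting at index 1) each handle a measurement subset.
    for level in range(1, dep_graph_len - 1):
        h = encoded[:, :, level]  # [B, S, H]
        for meas in measurements_per_dep_graph_level[level]:
            # Bernoulli/Categorical logits, or Normal (mu, sigma) for regression.
            preds[meas] = heads[meas](h)
    # The final node parameterizes the time-to-event distribution
    # (a LogNormal mixture or an Exponential in ESGPT).
    tte_params = tte_head(encoded[:, :, -1])
    return preds, tte_params
```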
This approach has three main limitations for our scale (hundreds of thousands of patients):
Predicting all of an event's measurements at once from a single embedding is very challenging
→ Proposed Solution: We predict only one measurement/triplet (code, time, value) at a time, making each prediction task simpler
No teacher forcing within timesteps makes learning harder
→ Proposed Solution: We use teacher forcing at every step of the triplet generation (code→time→value), allowing the model to learn from ground truth during training (see the sketch after this list)
Conditioning on embeddings rather than actual samples can be problematic when distribution variance is high
→ Proposed Solution: We explicitly condition on each sampled value (using the sampled code to select which time distribution to use, and concatenating the sampled time embedding before predicting values)
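To make the intra-step teacher forcing concrete, here is a minimal PyTorch sketch of hypothetical triplet heads at training time: the time head is indexed by the ground-truth code, and the value head is conditioned on an embedding of the ground-truth time. TripletHead and all layer shapes are illustrative assumptions, not existing code:

```python
import torch
import torch.nn as nn

class TripletHead(nn.Module):
    """Hypothetical code -> time -> value heads with intra-step teacher forcing."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.code_head = nn.Linear(hidden_size, vocab_size)
        # Per-code (rate logit, zero-inflation logit) for the time distribution.
        self.time_head = nn.Linear(hidden_size, 2 * vocab_size)
        self.time_embed = nn.Linear(1, hidden_size)
        # Per-code (mu, log_sigma, value-present logit) for the numeric value.
        self.value_head = nn.Linear(2 * hidden_size, 3 * vocab_size)

    def forward(self, h, true_code, true_time):
        """h: [B, S, H]; true_code: [B, S] long; true_time: [B, S] float."""
        B, S, _ = h.shape
        code_logits = self.code_head(h)                            # [B, S, V]

        # Teacher forcing step 1: select the time distribution of the *true* code.
        time_params = self.time_head(h).view(B, S, -1, 2)          # [B, S, V, 2]
        idx2 = true_code[..., None, None].expand(B, S, 1, 2)
        rate_logit, zero_logit = time_params.gather(2, idx2).squeeze(2).unbind(-1)

        # Teacher forcing step 2: condition the value head on the *true* time.
        t_emb = self.time_embed(true_time.unsqueeze(-1))           # [B, S, H]
        value_params = self.value_head(torch.cat([h, t_emb], -1)).view(B, S, -1, 3)
        idx3 = true_code[..., None, None].expand(B, S, 1, 3)
        mu, log_sigma, p_logit = value_params.gather(2, idx3).squeeze(2).unbind(-1)

        return code_logits, (rate_logit, zero_logit), (mu, log_sigma, p_logit)
```

Because every conditioning input is ground truth, all three heads can be trained in parallel from a single forward pass.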
Proposed sampling strategy (sketched in code below):
Using the transformer decoder, generate an embedding conditioned on past data $x_{<j}$
Sample code $y$ from a multinomial distribution over all codes
Sample time delta $t_y$ from a zero-inflated exponential distribution parameterized specifically for code $y$
Concatenate the $t_y$ embedding with the original embedding to predict $(\mu_y, \sigma_y, p_y)$ for each code:
Sample numeric value $v_y$ from $\mathcal{N}(\mu_y, \sigma_y)$ if the Bernoulli sample drawn with probability $p_y$ is $1$
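At generation time the same heads run autoregressively within the step, now conditioning on samples rather than ground truth. A minimal sketch, reusing the hypothetical TripletHead above (all names remain illustrative):

```python
import torch

@torch.no_grad()
def sample_triplet(h_last: torch.Tensor, head) -> tuple:
    """Sample one (code, time, value) triplet from a decoder embedding.

    h_last: [B, H] embedding for the next position; head: the TripletHead above.
    """
    B = h_last.shape[0]
    # 1. Sample the code from a multinomial over all codes.
    code = torch.distributions.Categorical(logits=head.code_head(h_last)).sample()

    # 2. Sample the time delta from the zero-inflated exponential of that code.
    rate_logit, zero_logit = (
        head.time_head(h_last).view(B, -1, 2)[torch.arange(B), code].unbind(-1)
    )
    is_zero = torch.bernoulli(torch.sigmoid(zero_logit)).bool()
    t = torch.distributions.Exponential(rate_logit.exp()).sample()
    t = torch.where(is_zero, torch.zeros_like(t), t)

    # 3. Condition on the *sampled* time, then gate the value on a Bernoulli draw.
    t_emb = head.time_embed(t.unsqueeze(-1))
    value_params = head.value_head(torch.cat([h_last, t_emb], -1)).view(B, -1, 3)
    mu, log_sigma, p_logit = value_params[torch.arange(B), code].unbind(-1)
    has_value = torch.bernoulli(torch.sigmoid(p_logit)).bool()
    v = torch.distributions.Normal(mu, log_sigma.exp()).sample()
    v = torch.where(has_value, v, torch.full_like(v, float("nan")))  # NaN = no value

    return code, t, v
```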
We need to modify the triplet-forecasting get_forecast_logits function here to perform precisely the procedure above.
Change 2: Fix time shift bug
Add a shift to the forecasted tokens. Currently, the model is predicting the same time step instead of forecasting the next time step.
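A minimal sketch of the intended alignment, assuming per-position code logits and integer targets (the function name and loss choice are illustrative):

```python
import torch
import torch.nn.functional as F

def shifted_code_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Score the prediction at position i against the event at position i + 1.

    logits: [B, S, V] per-position code logits; targets: [B, S] code indices.
    """
    pred = logits[:, :-1]   # predictions for the *next* event at each position
    gold = targets[:, 1:]   # the actual next events
    return F.cross_entropy(pred.reshape(-1, pred.size(-1)), gold.reshape(-1))
```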
TODOs:
[ ] Update get_forecast_logits to conditionally sample