Oufattole / meds-torch


Autoregressive Modeling #30

Open Oufattole opened 4 months ago

Oufattole commented 4 months ago
mmcdermott commented 4 months ago

Things to consider

  1. Flattening sequences results in blowing up the sequence length too much (may be an unavoidable challenge, and likely the least bad option). This also risks making the pre-training task more local, as discussed below.
  2. Predicting multiple aspects of an event (e.g., timestamp and code, multiple codes, code and value) at one time introduces probabilistic conditional independence assumptions that are not realistic.
  3. Whatever logic is used, key-value caching must be supported and must be tested and validated. This is very challenging to get right, especially if solutions involve nesting aspects of events or repeated processing of the same sequence with differing inputs to reflect updated views of the data (this should almost certainly be avoided).
  4. Certain forms of training or generation may prohibit certain abilities to interrogate the model or generate conditional futures in specific ways:
    • If you always order codes within a patient ID and timestamp in some consistent manner, then you won't be able to arbitrarily introduce codes during generation to see what the model thinks would happen if that code were observed. Imposing an ordering like this also introduces semi-artificial probabilistic relationships the model needs to learn, which may be detrimental (e.g., if just by chance the model generates a code late in the ordering, it will no longer generate codes earlier in the ordering; or if the model by chance generates an out-of-order code sequence, all subsequent events will be very confused).
    • If you simultaneously predict a boolean of whether a numerical value will be observed for a given code that has been observed or generated and the value that would be seen conditional on that boolean being true, then you may not be able to reliably toggle on a value to be observed during generation if the model does not think one is likely. E.g., suppose your sequence looks like "[TS @ 11:13pm] [CODE: HR] [VAL: 88] [CODE: DX] [CODE: BP/SYS] [VAL: 120] ..." and you structure your model so that the output head at each position in the sequence predicts jointly over the following things: 1. probabilities that sum to one that the next token is either (a) a TS token, (b) a CODE: * token (for all codes), or (c) a VAL token; 2. the numerical value for TS that would be observed conditional on the next token being a TS token; 3. the numerical value for VAL that would be observed conditional on the next token being a VAL token. Then you can reliably train your model with teacher forcing and generate data by simply sampling the next token from the categorical distribution and only using the predicted TS or VAL # distributions if you sample a TS or VAL token, respectively (this is true b/c the TS or VAL #s are always conditional on observing those codes, so your model is never being trained to predict conditional on not observing those codes). But, in this setting, you can't dynamically decide during generation that a timestamp should be generated even though the model assigns low probability to the next token being a TS, and then rely on the predicted # to give what that TS would be conditional on it being observed -- this is because if the predicted probability of a TS is sufficiently low (e.g., 0), then the model may have already encoded that information in earlier layers, and that may corrupt the predicted TS (or VAL) # in the output layer without any consequence from the loss, b/c you'd never observe a real # in that setting to compare against. If, instead, you structure your sequence like "[TS] [TS = 11:13pm] [CODE: HR] [VAL] [VAL = 88] [CODE: DX] [CODE: BP/SYS] [VAL] [VAL = 120] ..." then the prediction of the likelihood of the [VAL] token itself is less entangled with the predicted numerical value, so the model may have a greater ability to express conditional predictions of values or timestamps even when timestamps or values are unlikely.
  5. If the sequence expansion and simplifications to enable easy generation make the pre-training task easier, the model may be forced to learn fewer complex relationships in order to optimize the apparent loss to the same degree. In other words, a model trained to simultaneously (and independently) predict the set of all codes observed at the next unique timepoint is forced to simultaneously assess the risk of kidney disease, liver disease, other complications, etc., but if the model is trained to predict codes one by one, it may default to only predicting the likelihood of a rare complication in a subsequent timestamp if it has already seen highly correlated diagnoses at that timestamp, limiting your ability to rely on the model to predict complex things about the future from a static record. Said differently, it is easier for the model to learn to predict that seeing [CODE: DX1] raises the likelihood of [CODE: DX2] if those two are highly correlated than it is for the model to learn to predict both DX1 and DX2 given only historical data, so the pre-training "learning task" is easier when you flatten the sequence which may lead to limited transferability.
  6. We don't actually have good ways to evaluate autoregressive generation over medical data. We have proxies for this -- e.g., the ESGPT zero-shot evaluation codebase can help you take a task dataframe, restrict raw data to just the data that would be input to that task, generate possible synthetic futures for that data, re-interpret those synthetic futures as empirical task labels (as defined through a labeling function the user writes), then produce predictions by aggregating empirical task labels into predicted probabilities, but this methodology is imperfect. It only applies to defined tasks, and it is further limited in that predictions can only be made if sufficient time passes in the generated synthetic futures to encompass the range of events needed, so this approach really only emulates a "predict or defer" approach, not a pure prediction approach like fine-tuning. It also has limited ability to assess performance over time. Also, if your sequences are binned by time, then your ability to predict at all for tasks not on that time scale is greatly reduced.
  7. If the model gets confused, or there is a bug, or you just get unlucky, the model can generate absurd numbers of codes or predictions in ways that cause unexpected downstream computational issues. In particular, if your sequence is flattened, the model may never move on to a new timestamp. If your generation code waits for the model to generate up to a certain time delta, this would induce an infinite loop that your code might not naively catch. If your code nests the predicted codes or values within time bins and your model predicts an extremely large set of codes for the subsequent events, both the computational time and memory requirements to process that subsequent event and subsequent sequence elements will grow enormously and cause major slowdowns that are hard to diagnose.
  8. There is no way to universally "prompt" an autoregressive model in EHR space -- you can only generate a huge number of possible continuations and filter to those of interest, and this may also produce a biased sample depending on the properties of interest and how appropriately those are specified and the filtering occurs. If you instead try to modify the input sequence at some point during generation (even the very beginning) to force the model to output something unexpected, you risk moving the sample to either a biased region of the space or to a region of the input outside of the model's effective support, thereby rendering future generation outputs less reliable and (worse) the predictions of likelihood of the full data less reliable.
  9. Many concepts from pure categorical sequence generation do not apply to sequence generation that includes continuous values or time-to-event. E.g., in language modeling we can, during generation, sample the "most likely" next token (e.g., the mode of the classification distribution). If our distributional output for TTE prediction is an exponential distribution, the "mode" of that distribution is always 0 regardless of the predicted rate, so there is no meaningful way to sample the "most likely" next TTE in a continuous sense. If you sample the "expected" next value (e.g., the mean of the distribution), you can still produce highly unlikely samples if your distribution is bimodal. The most reliable approach that I know of (but one that may still be biased if MLE under misspecification of the predicted and target distributions does not yield a sampling process that would look realistic) is to exclusively sample the next observation from the predicted probability distribution directly and not apply any centralization function (see the sketch after this list). Relatedly, though I can't currently remember the details, there were some key barriers to using an analog of beam search in this setting.
  10. If you don't produce distributional outputs for all possible generation targets, but instead a point value (this is most commonly a strategy for continuous outputs, e.g., TTE or values), then (a) your model is effectively doing universal "mean" sampling over the distribution that is implicitly captured by the loss used (e.g., I believe MSE is equivalent to predicting a Gaussian distribution around the predicted value), which may not be appropriate for different regression tasks, and (b) you lose the ability to capture the effective variance between subsequent generation runs that you need in order to effectively sample over the true space of possible trajectories. E.g., for (b), if you predict a point value for TTE, then every sample future you generate for the patient starting at the end of a complete event will start with the exact same predicted time to next event, which is not realistic. Note that, as MSE equates to a Gaussian RV prediction at the mean, the model has some (implicit) notion of variance around this RV that factored into its modeling process, but by using the pointwise value alone, you are ignoring this information. Critical to this problem in (b) is that if you do this naively, predicting a pointwise output for all continuous values while sampling categoricals the way an LLM would (with some temperature, resampling, etc.), then you are applying different effective sampling processes to different parts of your sequence, which means you will observe a fake and biased variance around your samples over time b/c you will have variance for categoricals but not continuous values (subject to the same inputs). ESGPT has output layers that can be repurposed to predict probabilistic distributions for exponential, Gaussian, and mixture-of-lognormal values.
  11. When predicting distributions, questions about normalization, loss scale, stability of distributional parameters, and suitability of probabilistic estimators become complex and important. An option to simplify all of this is to convert all numerical values to digits in a string and literally predict those digits in sequence with a categorical estimator, up to the recorded precision in the raw data. This eliminates all numerical distributions at the further expense of a more local task and a longer sequence length.
  12. In predicting possible futures for medicine we often care more about calibration, in particular over rare events, than about predicting a "likely" sample with high probability. E.g., in LLMs, it is fine if the model is not calibrated to true human output b/c it always produces something that is a likely output for the input, but if a medical forecaster always predicts that the patient is healthy ad nauseam because that is the "most likely" option, then the forecaster will not be suited for risk estimation of long-term disease. This does not impact the validity of autoregressive training, which just learns a likelihood function over sequences, but does impact the validity or utility of generation strategies that focus on increasing the likelihood that the generated output is something that the model would score as highly likely (e.g., beam search).
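
As a concrete illustration of points 9 and 10, here is a minimal sketch (not code from this repo; the head name and the distribution choice are illustrative assumptions) of training and sampling a TTE head that outputs an exponential distribution rather than a point value:

```python
import torch
import torch.nn as nn


class ExponentialTTEHead(nn.Module):
    """Hypothetical head mapping the decoder's hidden state to an exponential rate."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.rate_proj = nn.Linear(hidden_dim, 1)

    def forward(self, hidden: torch.Tensor) -> torch.distributions.Exponential:
        # softplus keeps the predicted rate strictly positive.
        rate = nn.functional.softplus(self.rate_proj(hidden)).squeeze(-1) + 1e-6
        return torch.distributions.Exponential(rate)


def tte_nll(head: ExponentialTTEHead, hidden: torch.Tensor, observed_delta: torch.Tensor) -> torch.Tensor:
    """Training: maximize the log-likelihood of observed time deltas (teacher forcing)."""
    return -head(hidden).log_prob(observed_delta).mean()


def sample_next_delta(head: ExponentialTTEHead, hidden: torch.Tensor) -> torch.Tensor:
    """Generation: sample the delta from the predicted distribution itself.

    The mode of an exponential is always 0 and the mean can be unrepresentative,
    so .sample() is the option that preserves variance across generated trajectories.
    """
    with torch.no_grad():
        return head(hidden).sample()
```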
mmcdermott commented 4 months ago

One more:

  1. If your input, your sequence, or your temporal position embedding depends on time-derived signals -- here I use this term very specifically to mean things that can be deterministically computed from the timestamp of an event and static data about the subject, such as age, time of day, or season of year -- then you need to be able to efficiently and correctly compute those signals on the fly during generation, after you sample the next timepoint, lest you risk introducing errors or collapsing the generated sample out of the support of the model. A sketch of what this might look like follows.
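
A minimal sketch of recomputing such time-derived signals during generation; the feature set (age in years, hour of day, season) and the function name are assumptions for illustration, not this repo's API:

```python
from datetime import datetime


def time_derived_features(event_time: datetime, date_of_birth: datetime) -> dict[str, float]:
    """Deterministically recompute time-derived signals for a newly sampled timestamp.

    Must be called after each generated time delta is added to the running timestamp,
    so the next step's inputs stay inside the model's support.
    """
    age_years = (event_time - date_of_birth).days / 365.25
    hour_of_day = event_time.hour + event_time.minute / 60.0
    season = (event_time.month % 12) // 3  # 0=winter, 1=spring, 2=summer, 3=fall
    return {"age": age_years, "hour_of_day": hour_of_day, "season": float(season)}


# Usage during generation: after sampling delta_t,
#   current_time += delta_t
#   features = time_derived_features(current_time, subject_dob)
# and these features are embedded exactly as they were during training.
```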
Oufattole commented 4 months ago

TODO:

mmcdermott commented 4 months ago

For adding support for kv caching, you should:

  1. Just emulate existing transformer model code, such as that from NVIDIA in their Megatron repos or HuggingFace's code, in particular the Mixtral models or any other recent, modern LLM.
  2. Identify a set of properties you want to test during generation that you can easily verify -- e.g., that the outputs of a layer should be identical on the last token whether you give it the full sequence or give it the appropriate KV cache and let it predict just the last token. You can see some examples of tests like these (among other tests) in the ESGPT codebase, in some of the files below:
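
Separately from the ESGPT examples, a property test of this kind might look roughly like the sketch below; the `model(tokens, cache=...)` interface is a placeholder assumption, not an existing API in this repo or in ESGPT:

```python
import torch


def test_kv_cache_matches_full_forward(model, vocab_size: int = 32, seq_len: int = 16):
    """Last-position logits must match with and without the KV cache.

    Assumes a hypothetical interface where `model(tokens, cache=...)` returns
    (logits, updated_cache).
    """
    model.eval()
    tokens = torch.randint(0, vocab_size, (1, seq_len))

    with torch.no_grad():
        # Reference: one full forward pass over the whole sequence.
        full_logits, _ = model(tokens, cache=None)

        # Cached: run the prefix once, then feed only the final token plus the cache.
        _, cache = model(tokens[:, :-1], cache=None)
        last_logits, _ = model(tokens[:, -1:], cache=cache)

    torch.testing.assert_close(full_logits[:, -1], last_logits[:, -1], atol=1e-5, rtol=1e-4)
```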
Oufattole commented 4 months ago

For supporting fast generation we can use the following existing approaches:

For transformer decoders, we are using the x-transformers library, so we can follow their generation script for KV caching.

For Mamba we can follow their caching strategy here.

For the LSTM/RNN models we can support O(1)-per-step generation by following the incremental decoding approach used in fairseq.
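
For the RNN case, a minimal sketch (assumed names and sizes, not repo code) of O(1)-per-step decoding with `torch.nn.LSTM`, carrying the hidden state forward instead of re-encoding the prefix at every step:

```python
import torch
import torch.nn as nn

embed = nn.Embedding(100, 64)      # placeholder vocab size / dims
lstm = nn.LSTM(64, 64, batch_first=True)
head = nn.Linear(64, 100)


@torch.no_grad()
def generate(prefix: torch.Tensor, n_steps: int) -> list[int]:
    """Greedy decoding: the prefix is encoded once, then each step is O(1)."""
    out, state = lstm(embed(prefix))           # encode the prompt once
    token = head(out[:, -1:]).argmax(-1)       # first generated token
    generated = [token.item()]
    for _ in range(n_steps - 1):
        out, state = lstm(embed(token), state)  # reuse (h, c) instead of re-running the prefix
        token = head(out[:, -1:]).argmax(-1)
        generated.append(token.item())
    return generated


# e.g. generate(torch.tensor([[1, 5, 7]]), n_steps=10)
```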

Oufattole commented 4 months ago

In the conditionally independent triplet autoregressive forecasting scheme, where each token is the sum of time, value, and code embeddings, another option for handling numerical_values is to sample values for all codes at every token and just select the one for the predicted code. This has reasonable memory and compute cost, as it is exactly the same shape of head as the code prediction head (just with no softmax), but it allows modeling numerical values that are conditionally dependent on the code by indexing into the predicted values with the generated code index.
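
A minimal sketch of that scheme under assumed names and sizes (not the repo's actual modules): a second linear layer emits one value per code, and `torch.gather` selects the value matching the sampled code, so generated values stay conditioned on the generated code.

```python
import torch
import torch.nn as nn

hidden_dim, vocab_size = 128, 1000  # placeholder sizes

code_head = nn.Linear(hidden_dim, vocab_size)   # logits over codes
value_head = nn.Linear(hidden_dim, vocab_size)  # one numerical value per code (no softmax)


def sample_code_and_value(hidden: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """hidden: (batch, hidden_dim) last-position decoder state."""
    code_logits = code_head(hidden)                                      # (batch, vocab)
    code = torch.distributions.Categorical(logits=code_logits).sample()  # (batch,)
    all_values = value_head(hidden)                                      # (batch, vocab)
    value = torch.gather(all_values, 1, code.unsqueeze(1)).squeeze(1)    # value for the sampled code
    return code, value
```

At training time, the regression loss would analogously be applied only at the index of the observed code.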

Oufattole commented 4 months ago

First iteration for a transformer decoder is here: https://github.com/Oufattole/meds-torch/blob/dev/src/meds_torch/models/token_forecasting.py

Generation uses key-value caching.

mmcdermott commented 4 months ago

> In the conditionally independent triplet autoregressive forecasting scheme, where each token is the sum of time, value, and code embeddings, another option for handling numerical_values is to sample values for all codes at every token and just select the one for the predicted code. This has reasonable memory and compute cost, as it is exactly the same shape of head as the code prediction head (just with no softmax), but it allows modeling numerical values that are conditionally dependent on the code by indexing into the predicted values with the generated code index.

Take a look at this ESGPT code

It has some things implemented for what you're talking about here in a few ways. Happy to have another chat about this as well.