Oufattole opened 4 months ago
Things to consider
"[TS @ 11:13pm] [CODE: HR] [VAL: 88] [CODE: DX] [CODE: BP/SYS] [VAL:120] ..."
and you structure your model so that the output head at each position in the sequence predicts jointly over the following:

1. probabilities that sum to one that the next token is either (a) a TS token, (b) a CODE: * token (for all codes), or (c) a VAL token,
2. the numerical value for TS that would be observed conditional on the next token being a TS token,
3. the numerical value for VAL that would be observed conditional on the next token being a VAL token,

then you can reliably train your model with teacher forcing and generate data by simply sampling from the predicted token distribution and only using the predicted TS or VAL numerical distributions if you actually sample a TS or VAL token, respectively. (This is true because the TS and VAL numbers are always conditioned on observing those tokens, so your model is never being trained to predict them conditional on not observing those tokens.)

But, in this setting, you can't dynamically decide during generation that there should be a generated timestamp even though the model assigns a low probability to the next token being a TS, relying on the predicted number to give what that TS would be conditional on it being observed. This is because if the predicted probability of a TS is sufficiently low (e.g., 0), the model may have already discarded that information in earlier layers for this output position, and that can distort the predicted TS (or VAL) number in the output layer without consequence to the loss, because you'd never observe a real number in that setting to compare against.

If, instead, you structure your sequence like

"[TS] [TS = 11:13pm] [CODE: HR] [VAL] [VAL = 88] [CODE: DX] [CODE: BP/SYS] [VAL] [VAL = 120] ..."

then the prediction of the likelihood of the [VAL] token itself is less entangled with the predicted numerical value, so the model may be better able to express conditional predictions of values or timestamps even when those tokens are unlikely.
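To make the first scheme concrete, here is a minimal sketch of such a joint output head with the masked teacher-forcing loss (hypothetical module and argument names, not the meds-torch implementation; `next_number` is assumed to hold the observed number at [TS]/[VAL] positions and is ignored elsewhere):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class JointNextTokenHead(nn.Module):
    """Jointly predicts (1) a distribution over the next token's identity
    ([TS], every [CODE: *], [VAL]) and (2) the numeric TS / VAL values that
    would be observed conditional on those tokens appearing next."""

    def __init__(self, hidden_dim: int, num_codes: int):
        super().__init__()
        # token vocabulary: [TS] + num_codes code tokens + [VAL]
        self.token_logits = nn.Linear(hidden_dim, num_codes + 2)
        self.ts_head = nn.Linear(hidden_dim, 1)   # timestamp, conditional on [TS]
        self.val_head = nn.Linear(hidden_dim, 1)  # numeric value, conditional on [VAL]

    def forward(self, h):  # h: (batch, seq, hidden_dim)
        return (self.token_logits(h),
                self.ts_head(h).squeeze(-1),
                self.val_head(h).squeeze(-1))

def joint_loss(logits, ts_pred, val_pred, next_token, next_number, ts_id, val_id):
    """Teacher-forced loss. The numeric regression terms are masked so they only
    apply where the observed next token really is [TS] or [VAL]; the model is
    never penalized for numbers conditional on tokens that were not observed."""
    token_loss = F.cross_entropy(logits.transpose(1, 2), next_token)
    is_ts = (next_token == ts_id).float()
    is_val = (next_token == val_id).float()
    ts_err = F.mse_loss(ts_pred, next_number, reduction="none")
    val_err = F.mse_loss(val_pred, next_number, reduction="none")
    ts_loss = (ts_err * is_ts).sum() / is_ts.sum().clamp(min=1)
    val_loss = (val_err * is_val).sum() / is_val.sum().clamp(min=1)
    return token_loss + ts_loss + val_loss
```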
Another thing to consider: with a flattened sequence, it is easier for the model to learn that seeing [CODE: DX1] raises the likelihood of [CODE: DX2] if those two are highly correlated than it is for the model to learn to predict both DX1 and DX2 given only historical data. So the pre-training "learning task" is easier when you flatten the sequence, which may lead to limited transferability.
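As a toy illustration of that point (hypothetical token layout, not the meds-torch format): once an event's codes are laid out sequentially, intra-event correlations become ordinary next-token prediction targets, whereas an event-level scheme has to predict the whole set of codes from history alone.

```python
# One event at 11:13pm containing two correlated diagnoses, DX1 and DX2.

# Event-level view: both codes must be predicted jointly from prior history.
nested_view = [{"ts": "11:13pm", "codes": ["DX1", "DX2"]}]

# Flattened view: DX2 is predicted *after* DX1 is already in the input.
flattened_view = ["[TS = 11:13pm]", "[CODE: DX1]", "[CODE: DX2]"]
inputs = flattened_view[:-1]   # ["[TS = 11:13pm]", "[CODE: DX1]"]
targets = flattened_view[1:]   # ["[CODE: DX1]", "[CODE: DX2]"]

# The DX1 -> DX2 correlation is directly exploitable during pre-training in the
# flattened view, which makes the task easier but may transfer less well.
```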
One more: a TODO for supporting fast generation. For adding KV-caching (or equivalent incremental-state) support, we can use the following existing approaches (a generic sketch follows the list):

- For transformer decoders, we are using the x-transformers library, so we can follow their generation script for KV caching.
- For Mamba, we can follow their caching strategy here.
- For the LSTM/RNN models, we can support O(1) per-token generation following the incremental decoding approach used in fairseq.
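For reference, here is a minimal self-contained sketch of the KV-caching mechanism itself, for a single-head attention layer in plain PyTorch; this shows the general idea, not the actual x-transformers, Mamba, or fairseq APIs. At each step only the new token's query/key/value are computed, and cached keys/values make the per-step cost linear in the prefix length instead of re-running the whole prefix:

```python
import torch

class CachedSelfAttention(torch.nn.Module):
    """Single-head self-attention with a KV cache (illustrative sketch only)."""

    def __init__(self, dim: int):
        super().__init__()
        self.qkv = torch.nn.Linear(dim, 3 * dim)
        self.out = torch.nn.Linear(dim, dim)

    def forward(self, x_new, cache=None):
        # x_new: (batch, 1, dim), the embedding of the newly generated token only.
        q, k, v = self.qkv(x_new).chunk(3, dim=-1)
        if cache is not None:
            past_k, past_v = cache
            k = torch.cat([past_k, k], dim=1)  # append new key to cached keys
            v = torch.cat([past_v, v], dim=1)  # append new value to cached values
        # No causal mask is needed: the single query is the last position.
        attn = torch.softmax(q @ k.transpose(1, 2) / k.shape[-1] ** 0.5, dim=-1)
        return self.out(attn @ v), (k, v)  # output plus the updated cache

# Generation loop: each step feeds one token and reuses the cache.
layer = CachedSelfAttention(dim=32)
cache = None
x = torch.randn(1, 1, 32)  # first token's embedding
for _ in range(5):
    y, cache = layer(x, cache)
    x = torch.randn(1, 1, 32)  # in practice: embed the token sampled from y
```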
In the conditionally independent triplet autoregressive forecasting scheme, where each token is the sum of the time, value, and code embeddings, another option for handling numerical_values is to sample values for all codes at every token and just select the one corresponding to the predicted code. This has reasonable memory and compute cost, as the value head is exactly the same shape as the code prediction head (just with no softmax), but it allows modeling numerical values that are conditionally dependent on the code, by indexing the generated values with the sampled code index.
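A minimal sketch of that per-code value head (hypothetical names, not the actual meds-torch module): the value head mirrors the code head's output shape, and generation gathers the value at the sampled code's index.

```python
import torch
import torch.nn as nn

hidden_dim, num_codes = 64, 1000

code_head = nn.Linear(hidden_dim, num_codes)   # logits over codes
value_head = nn.Linear(hidden_dim, num_codes)  # one numeric value per code (no softmax)

h = torch.randn(8, hidden_dim)  # hidden states for a batch of 8 positions

# Sample a code per position, then select the value predicted for that code.
probs = torch.softmax(code_head(h), dim=-1)
code = torch.multinomial(probs, num_samples=1)      # (8, 1) sampled code ids
value = value_head(h).gather(-1, code).squeeze(-1)  # (8,) value for the sampled code
```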
First iteration for a transformer decoder is here: https://github.com/Oufattole/meds-torch/blob/dev/src/meds_torch/models/token_forecasting.py
Generation uses key-value caching.
Take a look at this ESGPT code.
It has some things implemented along the lines of what you're talking about here, in a few ways. Happy to have another chat about this as well.