Open NonoMalpi opened 1 year ago
Regarding `t=...`: this is a problem that is substantially simpler when using Diffrax. It allows you to write down your function for just a single batch element, and then just `vmap` it. This means you don't need to worry about the details of getting different times lined up across batch elements, and you can write down your loss function in a straightforward way.

The hidden state at time t is all you need to keep evolving for times > t. In this case you probably don't want to do this. (One thing you might consider, though, is gradually lengthening the time series: start by fitting your model to just all time points < T1, then all time points < T2, etc., for some values of T1/T2/... that you pick arbitrarily. This helps the model to gradually learn more and more complicated dynamics.)

Good luck!
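The "write everything for one batch element, then map it over the batch" pattern mentioned above can be sketched in a few lines. This is plain Python with a toy Euler solver and made-up names/data, not Diffrax itself; with Diffrax, `single_loss` would call `diffrax.diffeqsolve` and the plain loop below would become `jax.vmap`:

```python
# Sketch: write the loss for ONE element, then map over an irregular batch.
# Each element keeps its own observation times, so nothing needs to be
# padded or aligned across elements. Toy solver and data are hypothetical.

def euler_solve(f, y0, ts):
    """Toy fixed-step Euler integration of dy/dt = f(y), evaluated at times ts."""
    ys, y = [y0], y0
    for t0, t1 in zip(ts[:-1], ts[1:]):
        y = y + (t1 - t0) * f(y)
        ys.append(y)
    return ys

def single_loss(ts, y_obs):
    """Loss for one element: solve on ITS OWN time grid, compare to ITS data."""
    preds = euler_solve(lambda y: -y, y_obs[0], ts)
    return sum((p - o) ** 2 for p, o in zip(preds, y_obs)) / len(y_obs)

# Irregular batch: elements have different numbers of observation times.
batch = [
    ([0.0, 0.1, 0.3], [1.0, 0.9, 0.75]),
    ([0.0, 0.5], [2.0, 1.2]),
]
losses = [single_loss(ts, ys) for ts, ys in batch]  # "vmap" as a plain loop
mean_loss = sum(losses) / len(losses)
```

The point of the pattern is that `single_loss` never sees the batch dimension, so ragged observation times across patients stop being a modelling concern.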
Hi Patrick! Congratulations on your research work on neural differential equations. It's quite impressive, and thank you for the `torchcde` and Diffrax libraries.

I've been experimenting with the `torchcde` module for some time now. I've read the repository and the related papers: https://arxiv.org/abs/2005.08926 and https://arxiv.org/abs/2106.11028. Currently, I'm working on a time series prediction problem using neural CDEs. I will migrate it to Diffrax, but I have a question, and I think your experience can help me address it.
In a nutshell, I'm predicting a substance concentration in blood, denoted $Y$, for different patients. This concentration is irregularly sampled within and across patients. For instance, patient A has eleven measurements in ~50 minutes, patient B has only two measurements in ~4 hours, while patient C has five measurements in ~2 hours. The goal is to estimate $Y$ from a set of medical signals $X$ (e.g., heart rate, $O_2$ level) that are almost uniformly sampled for all patients (handling $X$ is not an issue). My objective is to predict $Y$ at time $t_n$, considering all historical information [$X$ and $Y$ (at least an initial condition of $Y$)] from $t_0$ to $t_{n-1}$ for each patient. This means I would like ten predictions of $Y$ for patient A, one prediction for patient B, and so on. This setup is quite different from the examples I have seen, as neural CDEs are mainly used for classification or static regression tasks (such as the BeijingPM10 or LOS examples in https://arxiv.org/abs/2106.11028).
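For concreteness, the data layout this setup implies can be sketched as follows. All numbers and names here are made up; the only real convention borrowed from `torchcde` is marking unobserved values as NaN so that interpolation can fill them in later:

```python
import math

NAN = math.nan

# One record per patient: own time grid, near-uniform X channels,
# and a sparse Y channel (NaN wherever no blood sample was taken).
patients = {
    "A": {
        "t": [0, 5, 10, 15, 20],         # minutes, near-uniform X sampling
        "hr": [80, 82, 81, 79, 83],      # heart rate (an X channel)
        "y": [1.2, NAN, 1.0, NAN, 0.8],  # blood concentration, sparse
    },
    "B": {
        "t": [0, 60, 120, 180, 240],
        "hr": [70, 72, 75, 74, 71],
        "y": [2.1, NAN, NAN, NAN, 1.4],
    },
}

# One prediction target per observed Y after the first
# (the first observation serves as the initial condition).
def n_targets(rec):
    observed = [v for v in rec["y"] if not math.isnan(v)]
    return len(observed) - 1

counts = {name: n_targets(rec) for name, rec in patients.items()}
```

With this layout, "ten predictions for patient A, one for patient B" falls out directly from how many $Y$ observations each patient has beyond the initial one.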
I've tried several strategies:
1. Training a neural CDE for each patient and validating it with a rolling window strategy. However, the out-of-sample predictions are quite similar to the last observed value of $Y$, indicating possible overfitting. Moreover, the `coeffs` obtained from interpolation change size across windows, which raises concerns about the approach's validity and effectiveness.

2. Instead of the window strategy, replacing `t=X.interval` with `t=X.grid_points` in `torchcde.cdeint(...)` (assuming that hidden channel `dim=1` directly represents $Y$). This change would give me an estimated array $\hat{Y}$ for all time steps considered, but the true value of $Y$ is recorded only at specific steps, and I'm not sure how to compute the loss function in this case.

3. Splitting $X$ and $Y$ into one-step $Y$-related measurements for all patients. For example, if $Y$ is available for patient A at $t_{a1}$, $t_{a2}$, $t_{a3}$, ..., I would divide $X$ and $Y$ for patient A into batches $[t_0, t_{a1-1}], [t_{a1}, t_{a2-1}]$, and so on. I would apply a similar strategy for patient B, and then group all batches from all patients, following the irregular data strategy described in irregular_data.py. This approach would allow me to perform train/validation/test splits, ensuring that all sets have the same `coeffs` length and making testing more manageable. However, I'm concerned that with this strategy I'm losing information, as predicting $Y$ at $t_{a2}$ would mean discarding records before $t_{a1}$ that may be useful.

As you can see, this is a question about preprocessing and train/test strategies, but given how input data is fed to neural CDEs, it might be worth thinking over carefully. Any comments would be greatly appreciated. Thank you very much!
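For the second strategy above ("predictions at every grid point, labels at only a few of them"), one common resolution is a masked loss: evaluate the model on the full grid, then average the squared error only over the indices where $Y$ was actually observed. A minimal stand-alone sketch in plain Python with made-up numbers (in PyTorch/torchcde you would do the same with a boolean mask tensor):

```python
import math

def masked_mse(preds, targets):
    """MSE computed only at time steps where a target was actually observed.

    preds:   model output at every grid point.
    targets: same length, with math.nan at unobserved steps.
    """
    errs = [
        (p - t) ** 2
        for p, t in zip(preds, targets)
        if not math.isnan(t)  # skip steps with no ground-truth Y
    ]
    if not errs:
        raise ValueError("no observed targets in this window")
    return sum(errs) / len(errs)

# Grid of 6 steps; Y observed only at steps 0, 2 and 5.
preds = [1.0, 0.8, 0.7, 0.6, 0.5, 0.4]
targets = [1.0, math.nan, 0.75, math.nan, math.nan, 0.35]
loss = masked_mse(preds, targets)
```

The unobserved steps contribute nothing to the gradient, so the model is still free to produce a full trajectory $\hat{Y}$ while being supervised only where measurements exist.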