google-research / torchsde

Differentiable SDE solvers with GPU support and efficient sensitivity analysis.

Irregular data and sampling posterior in latent_sde_lorenz.py #127

gopal-iyer commented 1 year ago

Hi,

This is very cool work and the documentation is great!

Two questions:

  1. I have already checked out @patrick-kidger's answer to issue #106 on irregular time series for latent SDE models. When the encoder is a GRU rather than a neural CDE, is it also a good idea to add missingness channels and fill-forward shorter samples, as in torchcde/irregular_data.py? I expect these extra channels would change the encoder's input_size at initialization, but I'm less clear on how they would interact with the contextualize() and drift (f()) functions, since each sample would then have its own unique set of timestamps. (There's a sketch of what I mean after this list.)

  2. In latent_sde_lorenz.py, I believe the sample() function only draws from the complete learned prior distribution. To reconstruct data by conditionally sampling the prior given a 'warm-up' time series, would I need to write a different sample() function? My guess is that I'd set the context from the warm-up segment before sampling z0 (as in lines 173–177 of latent_sde_lorenz.py), and then integrate with 'drift' set to 'f' instead of 'h'. Is this accurate? (A second sketch follows the list.)
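
To make question 1 concrete, here is a minimal sketch of the padding I have in mind. `pad_irregular_batch` is my own hypothetical helper (not torchsde or torchcde API), and I'm assuming each sample arrives as a `(times, values)` pair and that a timestep is either fully observed or fully missing:

```python
import torch


def pad_irregular_batch(samples, data_size):
    # samples: list of (ts, ys) pairs, where ts has shape (len_i,) and ys has
    # shape (len_i, data_size); lengths and timestamps differ per sample.
    all_ts = torch.cat([ts for ts, _ in samples]).unique(sorted=True)  # shared grid
    padded = []
    for ts, ys in samples:
        values = torch.zeros(len(all_ts), data_size)
        obs_mask = torch.zeros(len(all_ts), data_size)
        idx = torch.searchsorted(all_ts, ts)  # positions of this sample's timestamps
        values[idx] = ys
        obs_mask[idx] = 1.0
        # Fill forward: carry the last observation to unobserved grid points.
        # (Grid points before the first observation are left at zero.)
        last = values[0].clone()
        for t in range(len(all_ts)):
            if obs_mask[t, 0] == 0:
                values[t] = last
            else:
                last = values[t]
        # Append the missingness channels and a time channel, as in
        # torchcde/irregular_data.py.
        padded.append(torch.cat([values, obs_mask, all_ts.unsqueeze(-1)], dim=-1))
    return all_ts, torch.stack(padded, dim=1)  # (T, batch, 2 * data_size + 1)
```

The GRU encoder would then be initialized with input_size = 2 * data_size + 1, and since every sample now lives on the shared grid `all_ts`, contextualize() and f() could keep indexing context against a single common set of timestamps. Is that the intended approach?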
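
For question 2, this is roughly the conditional variant of sample() I was picturing. `sample_conditional` is hypothetical, and I'm going off my own reading of latent_sde_lorenz.py for the attribute names (`encoder`, `contextualize`, `qz0_net`, `projector`), so please correct me if any of this is off:

```python
import torch
import torchsde


@torch.no_grad()
def sample_conditional(model, xs_warmup, ts_warmup, ts_future, dt=1e-2, bm=None):
    # Encode the warm-up segment in reverse and set the context, mirroring
    # what forward() does before computing the ELBO.
    ctx = model.encoder(torch.flip(xs_warmup, dims=(0,)))
    ctx = torch.flip(ctx, dims=(0,))
    model.contextualize((ts_warmup, ctx))

    # Draw z0 from the approximate posterior q(z0 | warm-up) instead of the
    # learned prior p(z0) that sample() uses.
    qz0_mean, qz0_logstd = model.qz0_net(ctx[0]).chunk(chunks=2, dim=1)
    z0 = qz0_mean + qz0_logstd.exp() * torch.randn_like(qz0_mean)

    # Integrate with the posterior drift f so the warm-up window stays pinned
    # to the data. ts_future is assumed to start after ts_warmup ends.
    ts = torch.cat([ts_warmup, ts_future])
    zs = torchsde.sdeint(model, z0, ts, names={"drift": "f"}, dt=dt, bm=bm)
    return model.projector(zs)
```

One thing I'm unsure about: beyond the warm-up window, f() would keep reusing the final context vector, so perhaps the extrapolation phase should instead restart from `zs[len(ts_warmup) - 1]` with names={"drift": "h"}. Is that the right way to think about it?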

Any tips would be appreciated. Thanks!