patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Implement Latent SDE #101

Closed: anh-tong closed this issue 2 years ago

anh-tong commented 2 years ago

Hi,

I'm new to both JAX and diffrax.

I'm familiar with the latent SDE from the PyTorch library torchsde, but I'm finding it difficult to implement in diffrax.

The latent SDE has an augmented term that computes the KL divergence between the posterior and the prior by solving an ODE alongside the SDE.
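For reference, here is a sketch of the quantity that augmented ODE accumulates. It assumes, as in the torchsde latent SDE, that the posterior and prior SDEs share the same diffusion σ and that σ is invertible; f_φ (posterior drift), h_θ (prior drift) and L_t (the running KL) are notation introduced here. By Girsanov's theorem, the path-space KL is the expected integral of ½‖u‖², and the KL between the initial distributions is added separately:

```latex
\begin{aligned}
  dZ_t &= f_\phi(t, Z_t)\,dt + \sigma(t, Z_t)\,dW_t && \text{(posterior)} \\
  dZ_t &= h_\theta(t, Z_t)\,dt + \sigma(t, Z_t)\,dW_t && \text{(prior)} \\
  u(t, z) &= \sigma(t, z)^{-1}\bigl(f_\phi(t, z) - h_\theta(t, z)\bigr) \\
  \frac{dL_t}{dt} &= \tfrac{1}{2}\,\lVert u(t, Z_t) \rVert^2, \qquad L_{t_0} = 0 \\
  \mathrm{KL} &= \mathbb{E}\bigl[L_{t_1}\bigr]
\end{aligned}
```

The expectation is over posterior sample paths Z_t, and that is where the batch average enters.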

To compute the vector field for this ODE, we need to take an average over a batch of data.

However, JAX handles batches via vmap, which is applied on top of the call to diffeqsolve. But diffeqsolve requires the vector field to be specified up front, and at that point the vector field is not yet computed over a batch. Do you have any suggestions for this?

Another alternative I can think of is to integrate first (solve the ODE without averaging the vector field) and then take the expectation (average over the batch). This might double the memory usage. Do you think this is viable?

Thank you!

patrick-kidger commented 2 years ago

In terms of handling batches: the expectation for the KL divergence should be something you can calculate after the diffeqsolve. It doesn't have to happen inside the solve.

(And if for some reason you really want to take the expectation inside the diffeqsolve, you can still nest them: diffeqsolve -> vector field -> (vmap / KL).)

FWIW I have a first pass at the relevant pieces of the KL divergence calculation here. I never got around to putting together a latent SDE example, though, so the above might still be totally buggy/wrong.

anh-tong commented 2 years ago

Thanks for the pointer; that should be enough for the implementation. I'll close this issue for now and share the results if I get any.

anh-tong commented 2 years ago

Hi again, your code (with a small change) works pretty well.

https://colab.research.google.com/drive/14fRvx6jcIZPoimKYA5KpQduRfx8xc5dp?usp=sharing

Would you mind if I open a pull request adding the latent SDE to the examples folder?

patrick-kidger commented 2 years ago

Amazing, that'd be great.

If you also add it to mkdocs.yml then it'll appear in the documentation. (I'd suggest perhaps labelling the two neural SDE examples as "Neural SDE (GAN)" and "Neural SDE (VAE)"?)

Make sure to follow CONTRIBUTING.md to ensure the appropriate formatting automatically gets handled.

By the way, if you're feeling keen: something I'd quite like to do is combine both examples together into a VAE-GAN, where the decoder of the VAE is the generator of the GAN. Appropriately tuned this should be far better than either of them alone. (I've been musing about follow-up papers here, haha.)

anh-tong commented 2 years ago

Ok, I'll make a pull request based on your suggestion.

The VAE-GAN idea sounds interesting to me; I'll work on that as well. Let's discuss it further later.