anh-tong closed this issue 2 years ago.
In terms of handling batches: the expectation of the KL divergence is something you can calculate after the `diffeqsolve`. It doesn't have to happen inside the solve.
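Here is a minimal sketch of "expectation after the solve", assuming toy stand-ins for the networks. `posterior_drift`, `prior_drift`, `diffusion` and `solve_one` are hypothetical names. A fixed-step Euler-Maruyama loop written with `jax.lax.scan` takes the place of `diffrax.diffeqsolve` so that the sketch runs with JAX alone; in diffrax the `(z, kl)` pair would play the role of the state `y0`. The KL integrand assumed here is ½‖(f − h)/g‖², for a posterior and a prior SDE that share the diffusion g. Each sample accumulates its own KL, the whole single-sample solve is `jax.vmap`ped over the batch, and the mean is taken afterwards.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy stand-ins for the posterior drift network f, the prior
# drift h and the shared diagonal diffusion g (g must be nonzero).
def posterior_drift(t, z):
    return -z + jnp.sin(t)

def prior_drift(t, z):
    return -0.5 * z

def diffusion(t, z):
    return 0.5 * jnp.ones_like(z)

def solve_one(z0, key, t0=0.0, t1=1.0, num_steps=100):
    """Fixed-step Euler-Maruyama on the augmented state (z, kl) of ONE sample.

    Stands in for a diffrax.diffeqsolve call; no batch averaging happens here.
    """
    dt = (t1 - t0) / num_steps
    keys = jax.random.split(key, num_steps)

    def step(carry, inputs):
        z, kl = carry
        t, k = inputs
        f = posterior_drift(t, z)
        g = diffusion(t, z)
        # Per-sample KL integrand 0.5 * ||(f - h) / g||^2.
        u = (f - prior_drift(t, z)) / g
        dw = jnp.sqrt(dt) * jax.random.normal(k, z.shape)
        # z follows the SDE; the KL channel is an ODE (no noise term).
        return (z + f * dt + g * dw, kl + 0.5 * jnp.sum(u**2) * dt), None

    ts = t0 + dt * jnp.arange(num_steps)
    (z1, kl), _ = jax.lax.scan(step, (z0, jnp.zeros(())), (ts, keys))
    return z1, kl

# Batch by vmapping the whole single-sample solve...
batch_size, dim = 8, 3
z0s = jax.random.normal(jax.random.PRNGKey(0), (batch_size, dim))
keys = jax.random.split(jax.random.PRNGKey(1), batch_size)
z1s, kls = jax.vmap(solve_one)(z0s, keys)

# ...and only then take the expectation over the batch.
kl_estimate = jnp.mean(kls)
```

In this layout each sample carries just one extra scalar (its running KL) alongside its state, and the batch average is a plain `jnp.mean` once the solve has finished.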
(And if for some reason you really want to take the expectation inside the `diffeqsolve`, you can still nest things: `diffeqsolve` calls the vector field, and the vector field does the `vmap` over the batch and computes the KL term.)
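For the nested alternative, here is a similarly hedged sketch that reuses the same hypothetical toy drifts and the same hand-rolled loop in place of `diffeqsolve`. The state is now the whole batch `(zs, kl)`. The vector field `batched_drift` (a hypothetical name) uses `jax.vmap` over the samples and averages the KL integrand itself, so the solver only ever sees a single scalar KL channel.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy stand-ins for f (posterior drift), h (prior drift), g (diffusion).
def posterior_drift(t, z):
    return -z + jnp.sin(t)

def prior_drift(t, z):
    return -0.5 * z

def diffusion(t, z):
    return 0.5 * jnp.ones_like(z)

def kl_rate(t, z):
    """Per-sample KL integrand 0.5 * ||(f - h) / g||^2."""
    u = (posterior_drift(t, z) - prior_drift(t, z)) / diffusion(t, z)
    return 0.5 * jnp.sum(u**2)

def batched_drift(t, state):
    """Drift of the augmented whole-batch state (zs, kl); zs has shape (batch, dim)."""
    zs, _ = state
    dzs = jax.vmap(posterior_drift, in_axes=(None, 0))(t, zs)
    # The expectation over the batch happens here, inside the vector field.
    dkl = jnp.mean(jax.vmap(kl_rate, in_axes=(None, 0))(t, zs))
    return dzs, dkl

def solve_batch(zs0, key, t0=0.0, t1=1.0, num_steps=100):
    """Fixed-step Euler-Maruyama on the batched state; stands in for diffeqsolve."""
    dt = (t1 - t0) / num_steps
    keys = jax.random.split(key, num_steps)

    def step(carry, inputs):
        zs, kl = carry
        t, k = inputs
        dzs, dkl = batched_drift(t, carry)
        gs = jax.vmap(diffusion, in_axes=(None, 0))(t, zs)
        dws = jnp.sqrt(dt) * jax.random.normal(k, zs.shape)
        # The KL channel is an ODE: it gets no noise term.
        return (zs + dzs * dt + gs * dws, kl + dkl * dt), None

    ts = t0 + dt * jnp.arange(num_steps)
    (zs1, kl), _ = jax.lax.scan(step, (zs0, jnp.zeros(())), (ts, keys))
    return zs1, kl

zs0 = jax.random.normal(jax.random.PRNGKey(0), (8, 3))
zs1, kl_batch_mean = solve_batch(zs0, jax.random.PRNGKey(1))
```

Up to the random draws, both layouts compute the same Monte Carlo estimate, because averaging over the batch and summing over time steps commute. The nested form is only needed if the batch-averaged KL has to be available during the solve.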
FWIW I have a first pass at the relevant pieces of the KL divergence calculation here. I never got around to putting together a latent SDE example, though, so the above might still be totally buggy/wrong.
Thanks for the pointer. That'll be enough for the implementation. For now, I'll close this issue. I will share the results if I have any.
Hi again, your code (with a small change) works pretty well.
https://colab.research.google.com/drive/14fRvx6jcIZPoimKYA5KpQduRfx8xc5dp?usp=sharing
Would you mind if I open a pull request adding a latent SDE example to the `examples` folder?
Amazing, that'd be great.
If you also add it to `mkdocs.yml`, then it'll appear in the documentation. (I'd suggest perhaps labelling the two neural SDE examples as "Neural SDE (GAN)" and "Neural SDE (VAE)"?)
Make sure to follow CONTRIBUTING.md to ensure the appropriate formatting automatically gets handled.
By the way, if you're feeling keen: something I'd quite like to do is combine both examples together into a VAE-GAN, where the decoder of the VAE is the generator of the GAN. Appropriately tuned this should be far better than either of them alone. (I've been musing about follow-up papers here, haha.)
Ok, I'll make a pull request based on your suggestion.
The VAE-GAN idea seems interesting to me. I'll work on that as well. Let's discuss it further later.
Hi,
I'm new to both JAX and diffrax. I'm familiar with the latent SDE in torchsde (PyTorch), but I'm finding it difficult to implement in diffrax.
The latent SDE has an augmented term that computes the KL divergence between the posterior and the prior by solving an ODE in parallel with the SDE. To compute the vector field of this ODE, we need to take an average over a batch of data. JAX handles batched data via `vmap`, which would be applied around the call to `diffeqsolve`. But `diffeqsolve` requires the vector field to be specified beforehand, at which point it has not yet been computed over a batch. Do you have any suggestions for this?
Another alternative, I think, is to take the integral first (solving the ODE without averaging the vector field) and then take the expectation (averaging over the batch). This may double the memory usage. Do you think this might be viable?
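For reference, here is a sketch of the KL term that the augmented ODE accumulates. The notation is introduced here rather than taken from this thread: q is the path law of the posterior SDE with drift f_φ, p is the path law of the prior SDE with drift h_θ, and both share the same diagonal diffusion g (assumed nonzero). This covers only the path part of the KL and ignores the KL between the initial distributions. Because the integrand is non-negative, the expectation and the time integral can be swapped (Tonelli). That is why integrating per sample first and averaging over the batch afterwards targets the same quantity.

```latex
% Posterior (q) and prior (p) SDEs sharing the diffusion g:
%   dZ_t = f_\phi(t, Z_t)\,dt + g(t, Z_t)\,dW_t     (posterior)
%   dZ_t = h_\theta(t, Z_t)\,dt + g(t, Z_t)\,dW_t   (prior)
u(t, z) = \frac{f_\phi(t, z) - h_\theta(t, z)}{g(t, z)} \quad \text{(elementwise)}

\mathrm{KL}\big(q \,\|\, p\big)
  = \mathbb{E}_{q}\!\left[\int_{t_0}^{t_1} \tfrac{1}{2}\,\lVert u(t, Z_t)\rVert^2 \,dt\right]
  = \int_{t_0}^{t_1} \mathbb{E}_{q}\!\left[\tfrac{1}{2}\,\lVert u(t, Z_t)\rVert^2\right] dt

% Monte Carlo estimate from a batch of B sampled posterior paths Z^{(b)}:
\mathrm{KL}\big(q \,\|\, p\big)
  \approx \frac{1}{B}\sum_{b=1}^{B} \int_{t_0}^{t_1} \tfrac{1}{2}\,\lVert u(t, Z^{(b)}_t)\rVert^2 \,dt
```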
Thank you!