patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/

Can I make a continuous normalizing flow faster than RealNVP? #256

Open · agpcd opened 1 year ago

agpcd commented 1 year ago

I implemented the CNF example and added it as a head to a transformer to do inference over some continuous variables in a probabilistic program. However, it's way slower than my Equinox implementation of RealNVP, inspired by this excellent tutorial.
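(For context, the slow part is the log-likelihood itself: every evaluation is an augmented ODE solve that drags the trace of the Jacobian of the vector field along with it. A rough sketch of what I mean, where `vector_field`, `cnf_log_prob` and `base_log_prob` are illustrative stand-ins rather than code from the example:)

```python
import diffrax
import jax
import jax.numpy as jnp

def cnf_log_prob(vector_field, x, base_log_prob):
    # Augmented ODE: push the sample towards the base distribution while
    # accumulating the instantaneous change of variables, which needs the
    # trace of the Jacobian of the vector field at every step.
    def aug(t, state, args):
        y, _ = state
        f = lambda y: vector_field(t, y, args)
        dy = f(y)
        trace = jnp.trace(jax.jacfwd(f)(y))
        return dy, trace

    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(aug),
        diffrax.Tsit5(),
        t0=0.0,
        t1=1.0,
        dt0=0.05,
        y0=(x, jnp.zeros(())),
    )
    y1, delta = sol.ys
    # log p(x) = log p_base(y(1)) + integral of the Jacobian trace.
    return base_log_prob(y1[-1]) + delta[-1]
```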

Still, the elegance of a CNF is too good to ignore. Any recommendations?

patrick-kidger commented 1 year ago

Yes you can! Train it as a diffusion model ;) Seriously, diffusion models are just fancy ways to train a CNF -- at inference time they're the same thing, namely just an ODE. There's an Equinox implementation of score-based diffusions here that you can follow.
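To make the "at inference time they're the same thing" point concrete: once you have a trained score network, sampling is a single solve of the probability-flow ODE. Something like the following sketch, where the variance-exploding schedule and the `score_model(t, y)` signature are assumptions for illustration, not the exact setup in the linked example:

```python
import diffrax
import jax.numpy as jnp
import jax.random as jr

def sample(score_model, key, dim):
    sigma = 10.0  # assumed noise scale for the variance-exploding schedule

    def vector_field(t, y, args):
        # Probability-flow ODE: dy/dt = -0.5 * g(t)^2 * score(t, y),
        # with g(t)^2 = d/dt sigma(t)^2 for sigma(t) = sigma**t.
        g2 = sigma ** (2 * t) * 2 * jnp.log(sigma)
        return -0.5 * g2 * score_model(t, y)

    y1 = jr.normal(key, (dim,)) * sigma  # approx. sample from the t=1 marginal
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        t0=1.0,
        t1=1e-3,   # integrate backwards from noise to (almost) data
        dt0=-0.01,
        y0=y1,
    )
    return sol.ys[-1]
```

Note there's no Jacobian trace anywhere: that's why this ODE is so much cheaper per step than the CNF likelihood solve.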

The big difference is that a CNF is trained directly using log-likelihood (and thus can be composed into a larger autodifferentiable program very easily), but diffusions use their cunningly derived score-matching loss (which is much harder to fit into larger autodifferentiable programs). I mention this as you say you're doing this as a head to a transformer, and I'm not sure how cleanly the diffusion approach will work unless the transformer is entirely pretrained (i.e. not differentiated through when training the diffusion).
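To compare the losses directly: a CNF maximizes an exact log-likelihood, whereas a diffusion fits the score with something like denoising score matching. A sketch, using the same assumed variance-exploding schedule as above:

```python
import jax.numpy as jnp
import jax.random as jr

def dsm_loss(score_model, x, key, sigma=10.0, eps=1e-3):
    # Denoising score matching: perturb the data with the forward process's
    # Gaussian noise and regress the network onto the known Gaussian score.
    t_key, noise_key = jr.split(key)
    t = jr.uniform(t_key, (), minval=eps, maxval=1.0)
    std = jnp.sqrt(sigma ** (2 * t) - 1.0)  # marginal std at time t
    noise = jr.normal(noise_key, x.shape)
    x_t = x + std * noise
    pred = score_model(t, x_t)
    # True score of the perturbation kernel is -noise / std; weighting by
    # std**2 keeps the objective O(1) across t. No log p(x) appears anywhere,
    # which is why this is awkward to compose with other likelihood terms.
    return jnp.sum((pred * std + noise) ** 2)
```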

The fact that CNFs take a long time to train is the reason they're not really trained via log-likelihood any more. They simply weren't scalable -- the exact problem you're bumping into is well documented in the literature.