astrodeepnet / sbi_experiments

Simulation Based Inference experiments
MIT License

(optional) Verify the smoothness of ODE Flows and try to train with gradient constraints #20

Open: EiffL opened this issue 2 years ago

EiffL commented 2 years ago


https://github.com/astrodeepnet/sbi_experiments/blob/main/notebooks/ffjord/pure_jax_ode.ipynb

-> Implements an FFJORD in JAX, using SIRENs to parametrize the dynamics.
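A minimal sketch of the idea, not the notebook's code. Several choices here are assumptions: the data are 2D, the SIREN layer sizes and `OMEGA_0 = 30`, a fixed-step RK4 integrator instead of an adaptive solver, and an exact Jacobian trace instead of FFJORD's Hutchinson estimator (cheap in 2D). The names `init_siren`, `dynamics`, `augmented` and `log_prob` are hypothetical. The density is computed by integrating the augmented state `[z, log-density offset]` from t=1 (data) back to t=0 (base).

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

OMEGA_0 = 30.0  # SIREN frequency scale (assumed value)


def init_siren(key, sizes=(3, 32, 32, 2)):
    """Hypothetical SIREN for f(z, t): input (z_1, z_2, t), output dz/dt."""
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for i, (k, n_in, n_out) in enumerate(zip(keys, sizes[:-1], sizes[1:])):
        # SIREN-style init: U(-1/n_in, 1/n_in) for the first layer,
        # U(-sqrt(6/n_in)/omega_0, sqrt(6/n_in)/omega_0) for the others.
        bound = 1.0 / n_in if i == 0 else (6.0 / n_in) ** 0.5 / OMEGA_0
        w = jax.random.uniform(k, (n_in, n_out), minval=-bound, maxval=bound)
        params.append((w, jnp.zeros(n_out)))
    return params


def dynamics(params, z, t):
    """f(z, t): sine activations on hidden layers, linear output layer."""
    h = jnp.concatenate([z, jnp.atleast_1d(t)])
    for w, b in params[:-1]:
        h = jnp.sin(OMEGA_0 * (h @ w + b))
    w, b = params[-1]
    return h @ w + b


def augmented(params, state, t):
    """Time derivative of [z, offset]: [f(z, t), -Tr(df/dz)], exact trace."""
    z = state[:-1]
    jac = jax.jacfwd(dynamics, argnums=1)(params, z, t)
    return jnp.concatenate([dynamics(params, z, t), -jnp.trace(jac)[None]])


def log_prob(params, x, n_steps=32):
    """log p(x) = log N(z(0)) - int_0^1 Tr(df/dz) dt, via backward RK4."""
    dt = -1.0 / n_steps  # negative step: integrate from t=1 down to t=0

    def rk4_step(state, i):
        t = 1.0 + i * dt
        k1 = augmented(params, state, t)
        k2 = augmented(params, state + 0.5 * dt * k1, t + 0.5 * dt)
        k3 = augmented(params, state + 0.5 * dt * k2, t + 0.5 * dt)
        k4 = augmented(params, state + dt * k3, t + dt)
        return state + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4), None

    state0 = jnp.concatenate([x, jnp.zeros(1)])
    final, _ = jax.lax.scan(rk4_step, state0, jnp.arange(n_steps))
    z0, trace_integral = final[:-1], final[-1]
    return norm.logpdf(z0).sum() - trace_integral
```

Because the whole pipeline is differentiable, `jax.grad(log_prob, argnums=1)(params, x)` gives the model score at `x`, which is the quantity the score loss and the smoothness question below are about.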

FFJORD trained by NLL: [figure not preserved]

FFJORD trained by score loss: [figure not preserved]
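The two training objectives can be sketched as below. This assumes "score loss" means regressing the model score ∇_x log q(x) onto a known target score ∇_x log p(x) (available, for instance, when the simulator's log-density is differentiable); the notebook's exact loss may differ. `gaussian_log_prob` is a hypothetical stand-in model, there only so the snippet runs on its own; any `log_prob_fn(params, x)`, such as an FFJORD log-density, can be passed instead.

```python
import jax
import jax.numpy as jnp


def nll_loss(log_prob_fn, params, xs):
    """Maximum-likelihood objective: mean of -log q(x) over the batch."""
    return -jnp.mean(jax.vmap(log_prob_fn, in_axes=(None, 0))(params, xs))


def score_loss(log_prob_fn, params, xs, target_scores):
    """Score-matching objective against known target scores:
    batch mean of ||grad_x log q(x) - target||^2."""
    model_score_fn = jax.grad(log_prob_fn, argnums=1)
    model_scores = jax.vmap(model_score_fn, in_axes=(None, 0))(params, xs)
    return jnp.mean(jnp.sum((model_scores - target_scores) ** 2, axis=-1))


# Hypothetical stand-in model: diagonal Gaussian with params = (mean, log_std).
def gaussian_log_prob(params, x):
    mean, log_std = params
    std = jnp.exp(log_std)
    return jnp.sum(-0.5 * ((x - mean) / std) ** 2 - log_std - 0.5 * jnp.log(2.0 * jnp.pi))


# Toy target: N(1, 0.5^2) in 2D, whose score is -(x - 1) / 0.25.
xs = 1.0 + 0.5 * jax.random.normal(jax.random.PRNGKey(0), (256, 2))
target_scores = -(xs - 1.0) / 0.25
params = (jnp.zeros(2), jnp.zeros(2))

# Parameter gradients of each objective (what an optimizer step would use).
nll_grads = jax.grad(lambda p: nll_loss(gaussian_log_prob, p, xs))(params)
score_grads = jax.grad(lambda p: score_loss(gaussian_log_prob, p, xs, target_scores))(params)
```

Note that the score loss differentiates through `jax.grad` of the log-density, so training with it requires second derivatives of the flow with respect to its input, which is one reason smooth (e.g. sine) activations matter here.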

It's working pretty well on a first try. It could probably be improved slightly with more tuning of the model; for instance, the TensorFlow Probability FFJORD demo (https://www.tensorflow.org/probability/examples/FFJORD_Demo) uses only tiny networks and stacks 4 FFJORDs on top of each other.
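Stacking several small FFJORDs can be sketched as follows: the density is evaluated by pulling `x` back through each block in reverse order while one running scalar accumulates every block's trace integral. This is a sketch under assumptions, not the TFP demo's code: the tiny tanh MLPs, their sizes, the block ordering convention and the fixed-step RK4 integrator are all illustrative choices, and every function name is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def init_mlp(key, sizes=(3, 16, 16, 2)):
    """Hypothetical tiny tanh MLP for f_k(z, t); input (z_1, z_2, t)."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]


def mlp_dynamics(params, z, t):
    h = jnp.concatenate([z, jnp.atleast_1d(t)])
    for w, b in params[:-1]:
        h = jnp.tanh(h @ w + b)
    w, b = params[-1]
    return h @ w + b


def augmented(params, state, t):
    """[f_k(z, t), -Tr(df_k/dz)] for one block, exact 2x2 Jacobian trace."""
    z = state[:-1]
    jac = jax.jacfwd(mlp_dynamics, argnums=1)(params, z, t)
    return jnp.concatenate([mlp_dynamics(params, z, t), -jnp.trace(jac)[None]])


def pull_back(params, state, n_steps=16):
    """Run one block backwards from t=1 to t=0 with fixed-step RK4.
    The last entry of `state` keeps accumulating int Tr(df_k/dz) dt."""
    dt = -1.0 / n_steps

    def rk4_step(s, i):
        t = 1.0 + i * dt
        k1 = augmented(params, s, t)
        k2 = augmented(params, s + 0.5 * dt * k1, t + 0.5 * dt)
        k3 = augmented(params, s + 0.5 * dt * k2, t + 0.5 * dt)
        k4 = augmented(params, s + dt * k3, t + dt)
        return s + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4), None

    final, _ = jax.lax.scan(rk4_step, state, jnp.arange(n_steps))
    return final


def stacked_log_prob(param_stack, x):
    """log p(x) for base -> block_1 -> ... -> block_K -> data:
    log N(z) minus the sum of all blocks' trace integrals."""
    state = jnp.concatenate([x, jnp.zeros(1)])
    for params in reversed(param_stack):  # undo the block nearest the data first
        state = pull_back(params, state)
    return norm.logpdf(state[:-1]).sum() - state[-1]


# Four small blocks, as in the stacked setup of the TFP demo mentioned above.
param_stack = [init_mlp(k) for k in jax.random.split(jax.random.PRNGKey(0), 4)]
```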

It would be great to write down the exact equations for the score of an FFJORD, to demonstrate that the score is smooth.
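One possible starting point, assuming the usual FFJORD conventions (base sample at $t_0$, data at $t_1$). Dividing the continuity equation by $p_t$ gives $\partial_t \log p_t = -\nabla\cdot f_\theta - f_\theta\cdot\nabla\log p_t$; taking its gradient and following a trajectory $z(t)$, the Hessian-of-$\log p_t$ terms cancel, leaving a linear ODE for the score along the trajectory:

```latex
\frac{\mathrm{d}z}{\mathrm{d}t} = f_\theta(z, t), \qquad z(t_0) \sim p_0, \qquad x = z(t_1)

\log p_{t_1}(x) = \log p_0\big(z(t_0)\big)
  - \int_{t_0}^{t_1} \operatorname{Tr}\!\left(\frac{\partial f_\theta}{\partial z}\big(z(t), t\big)\right) \mathrm{d}t

s(t) := \nabla_z \log p_t\big(z(t)\big), \qquad
\frac{\mathrm{d}s}{\mathrm{d}t}
  = -\left(\frac{\partial f_\theta}{\partial z}\right)^{\!\top} s
    - \nabla_z \operatorname{Tr}\!\left(\frac{\partial f_\theta}{\partial z}\right),
\qquad s(t_0) = \nabla_z \log p_0\big(z(t_0)\big)

\nabla_x \log p_{t_1}(x) = s(t_1)
```

For a standard normal base, $s(t_0) = -z(t_0)$. The score depends on $f_\theta$ only through its first and second derivatives in $z$, so if $f_\theta$ is $C^k$ in $z$ with $k \ge 2$, standard smooth-dependence results for ODEs suggest that $x \mapsto \nabla_x \log p(x)$ is $C^{k-2}$; with sine activations a SIREN is $C^\infty$, so the score would be smooth as well.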