-> Implements an FFJORD in JAX, using SIRENs to parametrize the dynamics.
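For orientation, here is a minimal sketch of what such a model can look like. It is not the notebook's code: the names (`init_siren`, `dynamics`, `aug_dynamics`, `rk4`, `log_prob`), the frequency scale `W0`, the layer sizes and the step count are all assumptions made for illustration. It also uses a fixed-step RK4 integrator and the exact Jacobian trace, where a real FFJORD would typically use an adaptive solver and a stochastic trace estimate.

```python
import jax
import jax.numpy as jnp

W0 = 1.0  # assumed sine frequency scale; not taken from the notebook


def init_siren(key, sizes):
    """SIREN-style uniform init; biases are set to zero for simplicity."""
    params = []
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        key, sub = jax.random.split(key)
        bound = 1.0 / n_in if i == 0 else (6.0 / n_in) ** 0.5 / W0
        W = jax.random.uniform(sub, (n_in, n_out), minval=-bound, maxval=bound)
        params.append((W, jnp.zeros(n_out)))
    return params


def dynamics(params, z, t):
    """Vector field f(z, t): sine hidden layers, linear output layer."""
    h = jnp.concatenate([z, jnp.atleast_1d(t)])
    for W, b in params[:-1]:
        h = jnp.sin(W0 * (h @ W + b))
    W, b = params[-1]
    return h @ W + b


def aug_dynamics(params, y, t):
    """Augmented state y = [z, a] with dz/dt = f and da/dt = div f."""
    z = y[:-1]
    f = lambda z_: dynamics(params, z_, t)
    div = jnp.trace(jax.jacfwd(f)(z))  # exact divergence, fine in low dimension
    return jnp.append(f(z), div)


def rk4(func, y0, t0, t1, n_steps):
    """Fixed-step RK4 from t0 to t1 (the step is negative when t1 < t0)."""
    h = (t1 - t0) / n_steps

    def step(y, i):
        t = t0 + i * h
        k1 = func(y, t)
        k2 = func(y + 0.5 * h * k1, t + 0.5 * h)
        k3 = func(y + 0.5 * h * k2, t + 0.5 * h)
        k4 = func(y + h * k3, t + h)
        return y + h / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4), None

    y1, _ = jax.lax.scan(step, y0, jnp.arange(n_steps))
    return y1


def log_prob(params, x, n_steps=20):
    """log p(x): integrate from data (t=1) back to the Gaussian base (t=0)."""
    y1 = jnp.append(x, 0.0)
    y0 = rk4(lambda y, t: aug_dynamics(params, y, t), y1, 1.0, 0.0, n_steps)
    z0, minus_int_div = y0[:-1], y0[-1]  # a(0) = -integral of div f over [0, 1]
    log_p0 = -0.5 * jnp.sum(z0**2) - 0.5 * z0.shape[0] * jnp.log(2 * jnp.pi)
    return log_p0 + minus_int_div


def nll(params, batch):
    """Mean negative log-likelihood over a batch of shape (n, d)."""
    return -jnp.mean(jax.vmap(lambda x: log_prob(params, x))(batch))


# Model score: gradient of the log-density with respect to x
score = jax.grad(log_prob, argnums=1)

params = init_siren(jax.random.PRNGKey(0), [3, 32, 32, 2])  # input is (z, t), d = 2
x = jnp.array([0.3, -0.5])
lp = log_prob(params, x)
s = score(params, x)
```

Because the whole pipeline is differentiable, `jax.grad(nll)` gives gradients for training and `score` gives the model score at any point.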
FFJORD trained by NLL:
FFJORD trained by score loss:
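The note does not spell out what "score loss" means here. One plausible reading, assumed for this sketch, is a regression of the model score onto known target scores. The names `score_loss` and `toy_log_prob` are hypothetical, and the toy Gaussian only stands in for the FFJORD log-density so that the example is self-contained.

```python
import jax
import jax.numpy as jnp


def score_loss(log_prob, params, xs, target_scores):
    """Mean squared error between grad_x log p(x) and the target scores."""
    model_score = jax.vmap(jax.grad(log_prob, argnums=1), in_axes=(None, 0))(params, xs)
    return jnp.mean(jnp.sum((model_score - target_scores) ** 2, axis=-1))


def toy_log_prob(mu, x):
    """Stand-in model: unnormalized unit-variance Gaussian with mean mu."""
    return -0.5 * jnp.sum((x - mu) ** 2)


xs = jax.random.normal(jax.random.PRNGKey(0), (128, 2))
target_scores = -xs  # score of a standard normal, assumed known here

loss_at_truth = score_loss(toy_log_prob, jnp.zeros(2), xs, target_scores)
loss_shifted = score_loss(toy_log_prob, jnp.ones(2), xs, target_scores)
```

The loss only involves gradients of the log-density, so the normalizing constant of the model drops out.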
It works pretty well on a first try. It could probably be slightly better with more tuning of the model. Note, for instance, that in https://www.tensorflow.org/probability/examples/FFJORD_Demo they use tiny networks and stack 4 FFJORDs on top of each other.
It would be great to write down the exact equations for the score of an FFJORD to demonstrate that the score is smooth.
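A possible starting point, sketched here under assumed notation (dynamics $f_\theta$, base density $p_0$ at $t=0$, data at $t=1$) and not taken from the notebook: take the gradient of the continuity equation and follow it along a trajectory. This yields an ODE for the score $s(t) = \nabla_z \log p_t(z)\,|_{z=z(t)}$.

```latex
\begin{aligned}
\frac{dz}{dt} &= f_\theta(z(t), t), \qquad z(0) \sim p_0, \qquad x = z(1), \\
\log p_1(x) &= \log p_0(z(0)) - \int_0^1 \nabla_z \cdot f_\theta(z(t), t)\, dt, \\
\partial_t \log p_t &= -\nabla_z \cdot f_\theta - f_\theta^\top \nabla_z \log p_t
  \quad \text{(continuity equation)}, \\
\frac{ds}{dt} &= -\left(\frac{\partial f_\theta}{\partial z}\right)^{\!\top} s(t)
  - \nabla_z \big( \nabla_z \cdot f_\theta \big)(z(t), t),
  \qquad s(0) = \nabla_z \log p_0(z(0)).
\end{aligned}
```

If $f_\theta$ is a sine network, it is smooth in $z$ with bounded derivatives, and a Gaussian base has the smooth score $s(0) = -z(0)$. The right-hand sides above are then smooth, so standard results on smooth dependence of ODE solutions on initial conditions suggest that $\nabla_x \log p_1(x) = s(1)$ is smooth in $x$. This is a sketch of the argument, not a proof.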
https://github.com/astrodeepnet/sbi_experiments/blob/main/notebooks/ffjord/pure_jax_ode.ipynb