This PR should add a variant of the ramp bijector that can combine several ramps at once, in order to build complex bijections similar to those achieved with Spline Flow.
In early testing, I get a very good fit to the two-moons dataset with just 2 coupling layers of 8 components each (and that may already be more than necessary).
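To make the idea concrete, here is a minimal JAX sketch of what a multi-ramp transform on [0, 1] could look like. The parametrization is an assumption and not necessarily the one in this PR: each hypothetical component is the smooth step `ramp(t) = rho(t) / (rho(t) + rho(1 - t))` with `rho(t) = exp(-alpha / t)` for t > 0, rescaled so it maps [0, 1] onto [0, 1]. The components are mixed with softmax weights, and a small linear term `eps * x` makes the whole map strictly increasing. `multi_ramp`, `logits`, `a`, `b`, `alpha` and `eps` are illustrative names only.

```python
import jax
import jax.numpy as jnp


def _rho(t, alpha):
    """exp(-alpha / t) for t > 0 and exactly 0 for t <= 0 (smooth at t = 0)."""
    t_safe = jnp.where(t > 0, t, 1.0)  # avoid dividing by zero on the masked branch
    return jnp.where(t > 0, jnp.exp(-alpha / t_safe), 0.0)


def ramp(t, alpha=1.0):
    """Smooth step: 0 for t <= 0, 1 for t >= 1, non-decreasing in between."""
    r0 = _rho(t, alpha)
    r1 = _rho(1.0 - t, alpha)
    return r0 / (r0 + r1)  # r0 + r1 > 0 for every real t


def multi_ramp(x, logits, a, b, alpha=1.0, eps=1e-3):
    """Hypothetical multi-ramp map of [0, 1] onto [0, 1].

    logits, a, b have shape (K,). Assumes a > 0 and that each window
    [b, a + b] overlaps (0, 1), so every normalizer below is nonzero.
    """
    x = x[..., None]  # shape (..., 1), broadcasts against the K components
    lo = ramp(b, alpha)
    hi = ramp(a + b, alpha)
    g = (ramp(a * x + b, alpha) - lo) / (hi - lo)  # each component: 0 at x=0, 1 at x=1
    w = jax.nn.softmax(logits)  # convex mixture weights
    y = jnp.sum(w * g, axis=-1)
    # Small linear term makes the map strictly increasing, hence invertible.
    return (1.0 - eps) * y + eps * x[..., 0]
```

Individual ramps are flat outside their window, so the derivative of the mixture can get as small as `eps` in places; regions like that are one plausible reason a Newton iteration may struggle when inverting it.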
However, I ran into a few problems while implementing this:
- The Newton solver implemented by @b-remy does not always converge when inverting a multi-ramp bijector.
- I then tried the JAXopt package and its built-in bisection root finder ( https://jaxopt.github.io/stable/root_finding.html ). On the bright side, it is pretty fast and handles complex bijections with ease. However, for a reason I do not understand, second-order gradients misbehave: `value_and_grad` returns wrong values. I haven't been able to fix this yet.
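For reference, one way to invert a monotone map in pure JAX, without JAXopt, is sketched below. It is a hedged sketch, not a diagnosis of the JAXopt issue: a fixed-iteration bisection on the assumed bracket [0, 1], wrapped in `jax.custom_jvp` so that derivatives come from the inverse function theorem, dx/dy = 1/f'(x), instead of from differentiating through the loop. Because the JVP rule itself calls `inverse` and `jax.jvp`, JAX can differentiate the rule again, so second-order derivatives should follow as well. `inverse` and `_bisect` are hypothetical names, and `f` is assumed to be elementwise and increasing on [0, 1].

```python
from functools import partial

import jax
import jax.numpy as jnp


def _bisect(f, y, lo=0.0, hi=1.0, n_iter=50):
    """Solve f(x) = y on [lo, hi] by bisection; f must be increasing and elementwise."""
    lo = jnp.full_like(y, lo)
    hi = jnp.full_like(y, hi)

    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        go_right = f(mid) < y  # root lies in [mid, hi]
        return jnp.where(go_right, mid, lo), jnp.where(go_right, hi, mid)

    lo, hi = jax.lax.fori_loop(0, n_iter, body, (lo, hi))
    return 0.5 * (lo + hi)


@partial(jax.custom_jvp, nondiff_argnums=(0,))
def inverse(f, y):
    """x = f^{-1}(y); derivatives come from the JVP rule below, not the loop."""
    return _bisect(f, y)


@inverse.defjvp
def _inverse_jvp(f, primals, tangents):
    (y,), (y_dot,) = primals, tangents
    x = inverse(f, y)  # calling inverse again keeps higher-order derivatives available
    # f'(x) for an elementwise f: push a tangent of ones through f.
    _, f_prime = jax.jvp(f, (x,), (jnp.ones_like(x),))
    # Inverse function theorem: dx/dy = 1 / f'(x).
    return x, y_dot / f_prime
```

For example, with `f = lambda x: 0.5 * x + 0.5 * x**3`, `inverse(f, 0.408)` recovers x ≈ 0.6, and `jax.grad` gives approximately 1/f'(0.6) = 1/1.04. Fixing the number of iterations keeps the loop shape static under `jit`; 50 halvings of a unit interval already go below float32 resolution.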
I'll leave this as a draft PR for now, because the bijector itself seems to work but its gradients do not, which is annoying.