Results with 1 coupling layer (result plots omitted):
- loss = nll
- loss = score diff / 1000
- loss = nll + score diff / 1000

And with 2 coupling layers:
- loss = nll
- loss = score diff / 1000
It should work a little better than this, I think, but that probably depends on your training strategy. I start at lr=0.001, run it for "long enough", and then lower the learning rate.
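The staged learning-rate strategy above can be sketched like this (a toy JAX example; the objective, parameter, and step counts are illustrative stand-ins, not the actual flow training code):

```python
import jax
import jax.numpy as jnp

# Hypothetical toy objective standing in for the flow's loss;
# the minimum is at theta = 3.0.
def loss_fn(theta):
    return jnp.sum((theta - 3.0) ** 2)

grad_fn = jax.grad(loss_fn)

theta = jnp.zeros(2)
# Stage 1: train "long enough" at lr=1e-3; stage 2: lower the learning rate.
for lr, steps in [(1e-3, 2000), (1e-4, 1000)]:
    for _ in range(steps):
        theta = theta - lr * grad_fn(theta)

# theta should now be close to the minimum at 3.0
```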
This PR adds a bunch of things:
New inverse tools for building bijections without analytic inverses, based on jaxopt, because the Newton root-finding method we had before wasn't stable enough. Note that for now we can't directly use the implicit gradients from jaxopt because of what looks like a small bug (https://github.com/google/jaxopt/issues/141). For the time being I've repurposed the implicit gradients from @b-remy to bypass the jaxopt gradients, and everything seems to work fine.
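The core idea of a numerical inverse can be sketched with plain bisection (a standalone illustration: the PR itself goes through jaxopt root finders with custom implicit gradients, and the monotone function below is made up for the example):

```python
import jax
import jax.numpy as jnp

# Forward bijection with no closed-form inverse (illustrative choice,
# not the PR's actual transform): strictly increasing, so bisection works.
def forward(x):
    return x + 0.5 * jnp.sin(x)

def inverse(y, lo=-10.0, hi=10.0, n_steps=60):
    """Invert `forward` by bisection on the bracket [lo, hi]."""
    lo, hi = jnp.asarray(lo), jnp.asarray(hi)

    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        too_big = forward(mid) > y
        # keep the half-interval that still brackets the root
        return jnp.where(too_big, lo, mid), jnp.where(too_big, mid, hi)

    lo, hi = jax.lax.fori_loop(0, n_steps, body, (lo, hi))
    return 0.5 * (lo + hi)

y = forward(1.2345)
x = inverse(y)  # recovers ~1.2345
```

Differentiating through such an inverse is where the implicit gradients come in; this sketch only shows the forward solve.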
New type of coupling layer based on a sigmoid applied to an affine transformation. Also adds mixtures of sigmoids, for extra expressivity. This type of transform comes from the bgflow code, here.
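A minimal sketch of what such a mixture-of-sigmoids map looks like (parameter names and shapes here are my own for illustration, not the PR's API):

```python
import jax
import jax.numpy as jnp

def mixture_of_sigmoids(x, log_a, b, w_logits):
    """Monotone scalar map built from a mixture of sigmoids, in the spirit
    of the bgflow-style transform. Positive slopes and convex mixture
    weights keep it strictly increasing, hence invertible (numerically)."""
    a = jnp.exp(log_a)            # slopes > 0
    w = jax.nn.softmax(w_logits)  # convex weights, sum to 1
    return jnp.sum(w * jax.nn.sigmoid(a * x + b))

log_a = jnp.zeros(3)
b = jnp.linspace(-1.0, 1.0, 3)
w_logits = jnp.zeros(3)

y0 = mixture_of_sigmoids(0.0, log_a, b, w_logits)
y1 = mixture_of_sigmoids(0.5, log_a, b, w_logits)
# derivative of the 1d map via autodiff, as needed for log-det-Jacobian terms
dydx = jax.grad(mixture_of_sigmoids)(0.0, log_a, b, w_logits)
```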
Annnnnd, finally, we get flows that train by score matching \o/ (see notebook). This can probably be fine-tuned further, but at least it shows that we get sensible gradients. You can also train it directly by NLL, which works quite well; a single coupling layer seems to be enough here.
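To illustrate the score-matching objective, here is a toy version where the "model" is just a Gaussian with a learnable mean standing in for the flow's log_prob (everything here is illustrative, not the notebook's code):

```python
import jax
import jax.numpy as jnp

# Toy model log-density: unnormalized Gaussian with learnable mean `mu`.
def model_log_prob(x, mu):
    return -0.5 * jnp.sum((x - mu) ** 2)

# The model's score is the gradient of its log-density w.r.t. x.
model_score = jax.grad(model_log_prob, argnums=0)

def target_score(x):
    return -x  # score of a standard normal

def score_diff_loss(mu, xs):
    """Mean squared difference between model and target scores."""
    diffs = jax.vmap(lambda x: model_score(x, mu) - target_score(x))(xs)
    return jnp.mean(jnp.sum(diffs ** 2, axis=-1))

xs = jnp.linspace(-2.0, 2.0, 16)[:, None]
loss_at_match = score_diff_loss(jnp.zeros(1), xs)  # vanishes at the target
loss_off = score_diff_loss(jnp.ones(1), xs)        # nonzero away from it
```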
There are a few things we will want to improve, in particular our bijectors here only work in the 2d case, but we can look at that in a separate issue. This code should be enough to start experimenting.