Open EiffL opened 2 years ago
Ok well, turns out I had a colab notebook with everything in one place: https://colab.research.google.com/drive/1HPom85QIjugHaL2RkO-5TWle6ZeoVBWC?usp=sharing
Can you see if it is working for you? and if so, can you add your version of this to this repo?
Ah, and instead of using the sklearn two moons dataset, you can use the pure TFP one from this notebook: https://colab.research.google.com/drive/1yRsh1Kmb6O1J6Rx3v1hX7-oS9cQUyGiM?usp=sharing
The advantage is that it will also allow you to compute gradients ;-)
Ok thanks a lot ! I will look into all of this :)
Learning the two moons from tensorflow using RealNVP https://colab.research.google.com/drive/1E2o54mt8KHlnWkwJCaEpzBunmTR3NmWC?usp=sharing
Learning the two moons from tensorflow using RealNVP + using the score https://colab.research.google.com/drive/1t4DaL02o31OCOFifDaQS2B1f_QN5-iRq?usp=sharing
I can't use @jax.jit for the get_batch function (from this notebook : https://colab.research.google.com/drive/1t4DaL02o31OCOFifDaQS2B1f_QN5-iRq?usp=sharing ), when I use it I get this error : 'IndexError: tuple index out of range'
Could you try the following?
@jax.jit
def get_batch(batch_size, seed):
    # build the distribution inside the jitted function
    two_moons = get_two_moons(sigma=0.05)
    batch = two_moons.sample(batch_size, seed=seed)
    score = jax.vmap(jax.grad(two_moons.log_prob))(batch)
    return batch, score
Maybe it's an issue coming from the fact that you build the distribution outside of the jitted function
So, @Justinezgh, I think you are already pretty much all setup to start some fun research and experiments, so I want to show you some preliminary work on this stuff we did with @b-remy last year.
We were testing a technique called denoising score matching to learn the score field (not the distribution itself), and we did some tests against what a conventional Normalizing Flow could achieve. Here is a relevant plot: (from this notebook: https://github.com/b-remy/score-estimation-comparison/blob/normalizing_flows/notebooks/NF-DAE-SN-comparison.ipynb)
It shows that when training a Normalizing Flow just for density estimation, the score field can go all wonky. Also, if you think about the change of variable formula in a Normalizing Flow, the score will have two terms: one that comes from the inverse mapping, and one that comes from the Jacobian determinant. For a RealNVP, @b-remy also made this plot: (https://github.com/b-remy/score-estimation-comparison/blob/normalizing_flows/notebooks/NFlows_where_come_from_the_failures.ipynb) which shows that the determinant part seems to be responsible for most of the bad behavior. It probably implies that the particular shape of the RealNVP determinant is not very regular.
This makes me think that a first angle of attack is to check that, for a given choice of normalizing flow architecture, the log density is indeed continuously differentiable. And thinking about the log determinant term is probably a good idea.
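For reference, the two terms mentioned above can be written out from the change of variable formula. The notation here is an assumption made for illustration (base density p_z, flow f mapping z to x), not something taken from the linked notebooks:

```latex
\log p_x(x) = \log p_z\big(f^{-1}(x)\big) + \log\left|\det J_{f^{-1}}(x)\right|

\nabla_x \log p_x(x)
  = \underbrace{J_{f^{-1}}(x)^{\top}\, \nabla_z \log p_z(z)\Big|_{z = f^{-1}(x)}}_{\text{inverse mapping term}}
  + \underbrace{\nabla_x \log\left|\det J_{f^{-1}}(x)\right|}_{\text{Jacobian determinant term}}
```

The second term involves second derivatives of the flow, so any lack of smoothness in the coupling layers shows up there first.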
You can also have a look at one of the seminal papers on score matching: https://www.cs.helsinki.fi/u/ahyvarin/papers/JMLR05.pdf
Impact of the nb of coupling layers (affine coupling layers) on the score field : https://colab.research.google.com/drive/1H0Q_hgb0Yjtqvyg9RKeqTvt5lSZBNiap?usp=sharing
Same but with Neural Spline Flows : https://colab.research.google.com/drive/1IFDmsNUTsHIjQpjnXKIAG3PUeyx6NLux?usp=sharing
And I still have problems with @jax.jit. The error message ends with: '... jit, try using static_argnums or applying jit to smaller subfunctions.'
batch_size = 512

@jax.jit
def get_batch(seed):
    two_moons = get_two_moons(sigma=0.05)
    batch = two_moons.sample(batch_size, seed=seed)
    return batch
Simple fix for the jitting of get_batch: removing the batch_size argument.
@Justinezgh this is all super interesting. Two questions:
[ ] Have you checked (at least theoretically) that the log prob of a normalizing flow using a realNVP is at least twice differentiable? The leaky-relu for instance shouldn't be, and I think it actually has 0 second order gradients (so 0 gradients of the score) almost everywhere. Which could explain why we are having difficulties training on the score.
[ ] Can you try to learn the score field with a simple regression network instead of a Normalizing Flow, so directly training a function s_\theta(x) to learn the score field, with a dense neural network for instance. If this works well, it means that there is nothing in principle wrong with the score matching loss, and that if there are difficulties, they must come from the particular architecture of the Normalizing Flow.
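The first point can be checked numerically. This is only a small sketch on the activations themselves (not on a full RealNVP): it evaluates the second derivative of leaky-relu and of sin at a few points away from zero.

```python
import jax
import jax.numpy as jnp


def second_derivative(fn):
    # d^2 fn / dx^2 for a scalar function of a scalar input
    return jax.grad(jax.grad(fn))


# evaluation points chosen away from the kink of leaky-relu at x = 0
xs = jnp.array([-2.0, -0.5, 0.5, 2.0])

d2_leaky = jax.vmap(second_derivative(jax.nn.leaky_relu))(xs)
d2_sin = jax.vmap(second_derivative(jnp.sin))(xs)

print("leaky_relu second derivative:", d2_leaky)
print("sin second derivative:", d2_sin)
```

The leaky-relu values are all zero (it is piecewise linear), while sin gives -sin(x), so a loss on the score only sees curvature in the second case.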
Hi @Justinezgh , note that if it is more convenient, you can also keep the batch_size argument by specifying to jit that it is a static argument.
import functools

@functools.partial(jax.jit, static_argnums=(1,))
def get_batch(seed, batch_size):
    two_moons = get_two_moons(sigma=0.05)
    batch = two_moons.sample(batch_size, seed=seed)
    return batch
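To make the static-argument approach runnable outside the notebook, here is a minimal sketch that also returns the score. It does not use get_two_moons: the density below is a stand-in (an isotropic 2D Gaussian with a hypothetical SIGMA), chosen only because its score is known in closed form.

```python
import functools

import jax
import jax.numpy as jnp

SIGMA = 0.5  # hypothetical scale of the stand-in density


def log_prob(x):
    # Stand-in for two_moons.log_prob (assumption): isotropic 2D Gaussian
    return -0.5 * jnp.sum(x**2) / SIGMA**2 - jnp.log(2 * jnp.pi * SIGMA**2)


@functools.partial(jax.jit, static_argnums=(1,))
def get_batch(seed, batch_size):
    # batch_size fixes an array shape, so it has to be static under jit
    batch = SIGMA * jax.random.normal(seed, (batch_size, 2))
    # per-sample score: gradient of the log density w.r.t. each sample
    score = jax.vmap(jax.grad(log_prob))(batch)
    return batch, score


batch, score = get_batch(jax.random.PRNGKey(0), 512)
print(batch.shape, score.shape)
```

For this stand-in the score is -x / SIGMA**2, which gives an easy sanity check. Note that each new batch_size value triggers a recompilation, since it is part of the jit cache key.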
Sorry, I was too curious.... I quickly tried to train a regression network under a score matching loss to make sure things were not crazy. And it seems to work pretty well:
The loss function nicely goes to zero instead of jumping around as in the NF examples.
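A minimal sketch of this kind of score regression, in pure JAX. It is not the code from the PR: the target is a standard normal (score -x) instead of the two moons, and the network sizes, tanh activation, learning rate and step count are all arbitrary assumptions.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-2  # arbitrary choice for this sketch


def init_params(key, sizes=(2, 64, 64, 2)):
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]


def score_net(params, x):
    # dense network s_theta(x) mapping R^2 -> R^2
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def true_score(x):
    # stand-in target (assumption): standard normal, grad log p(x) = -x
    return -x


def loss_fn(params, x):
    # plain regression of the network output onto the known score
    err = score_net(params, x) - true_score(x)
    return jnp.mean(jnp.sum(err**2, axis=-1))


@jax.jit
def update(params, x):
    loss, grads = jax.value_and_grad(loss_fn)(params, x)
    params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


key, init_key = jax.random.split(jax.random.PRNGKey(0))
params = init_params(init_key)
losses = []
for step in range(500):
    key, sub = jax.random.split(key)
    x = jax.random.normal(sub, (256, 2))
    params, loss = update(params, x)
    losses.append(float(loss))

print("first loss:", losses[0], "last loss:", losses[-1])
```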
For fun I also tried to train the same network, but making it output just a scalar and constraining its gradients, and that doesn't train at all:
And for even more fun, I tried training the same model again by constraining the grads, replacing the relu activation by a sin function, as proposed in https://arxiv.org/abs/2006.09661
And BAM! By magic it works \o/ (note: for the background on the right I use exp( scalar output of the network ))
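One way to see why the activation matters here: when the score is defined as the gradient of a scalar network, its Jacobian is the Hessian of that network. The sketch below (hypothetical layer sizes, untrained random weights) compares that Hessian for relu and sin activations.

```python
import jax
import jax.numpy as jnp


def init_params(key, sizes=(2, 32, 32, 1)):
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]


def make_score(activation):
    def scalar_net(params, x):
        for w, b in params[:-1]:
            x = activation(x @ w + b)
        w, b = params[-1]
        return (x @ w + b)[0]  # scalar output

    # the score is constrained to be the gradient of the scalar output
    return jax.grad(scalar_net, argnums=1)


params = init_params(jax.random.PRNGKey(0))
x = jnp.array([0.3, -0.7])

# Jacobian of the score w.r.t. x, i.e. Hessian of the scalar network
hess_relu = jax.jacfwd(make_score(jax.nn.relu), argnums=1)(params, x)
hess_sin = jax.jacfwd(make_score(jnp.sin), argnums=1)(params, x)

print("relu:\n", hess_relu)
print("sin:\n", hess_sin)
```

With relu the network is piecewise linear, so the resulting score field is piecewise constant and its Jacobian is zero almost everywhere. With sin the Jacobian is a non-trivial symmetric matrix.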
And training goes super easily:
=> All codes available here: https://github.com/astrodeepnet/sbi_experiments/pull/4
Ahaha yep ^^ sorry, this had been bugging me all afternoon and was dying to try, it's pretty fun stuff :-)
So, the next logical step is to build a NF that is by construction C^\infty.
@b-remy reminded me of this paper: https://arxiv.org/pdf/2110.00351.pdf where they actually propose a coupling layer that should be continuously differentiable, to place in a RealNVP. Probably worthwhile to take a look.
So just to see, I tried to use the sin activation function for the NN of the affine coupling layer :
For the NF with 3 coupling layers :
Notebook : https://colab.research.google.com/drive/1ZU-w76vJ81-PArB9vr1x9fi7qpZ1AOnu?usp=sharing
interesting interesting yeah, it doesn't seem to help directly :-/
So here is what they say in section 4 of (2110.00351):
So what they are saying is that with an affine coupling layer, you lose expressivity in the gradients of the log p. And they also say that the Neural Spline Flows have poor gradients because they are only C1.
So I think we could try the following: Using a C^\infty coupling layer, and training under the Score Matching loss (because whether or not the flow can train under the SM loss on its own will tell us if the model is well adapted).
Alternatively, it might be possible to use a MAF instead of a RealNVP, because it's possible that if the masked autoencoder in the MAF layer is C^\infty, so will be the flow
(just for the record, what I said there about MAF was stupid, you still have an Affine Coupling with a MAF)
Just to be sure, is this the function that we want to place in a RealNVP?
If yes, do we want to use f to define the shift and the scale part ? Because both shift and scale are R^d -> R^(D-d) so do we have to do some kind of projection for f(x) ? Like we define f(x_i) := (1-c).((g(x_i)-g(0))/(...)) + c.x_i so f(x) \in R^d and then we project in R^(D-d) ?
Actually I'm not sure that f was made to be used in a RealNVP, idk..
That's a good question. And yes that's the coupling we might want to use, instead of an affine coupling.
So you don't generate shift and scale parameters, instead you generate these a,b,c parameters which are the outputs of some neural network which takes R^d inputs and return R^(D-d) outputs, and the function g is a bijection in R^(D-d).
You can have a look at how the Spline flows work, it's a bit different, but illustrates how a parametrisation can deviate from affine.
It may not be 100% trivial because I think you would have to define a TFP bijector to implement the mapping f. It shouldn't be too difficult, but will take a bit of coding.
Ah, and there is another approach we could take i think.... We could use a ffjord, and there is one easily usable in the TF version of TFP (so not in Jax unfortunately). I think if the ode function is sufficiently smooth, so is the ode flow.
+1 I was also thinking that Continuous Normalizing Flows (the flow of transformations being continuous here) such as Neural ODE (1806.07366) or FFJORD (1810.01367) would be an interesting approach to look at in parallel!
I'm not sure that the function f(x) = (1-c)((g(x)-g(0))/(g(1)-g(0)))+cx has an analytical inverse. At least for rho(x) = exp(-1/alpha*x**beta)
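For concreteness, here is a sketch of the forward map f on [0, 1]. The thread does not spell out g, so the form used below is an assumption: g(x) = ramp(a (x - b) + 1/2) with ramp(u) = rho(u) / (rho(u) + rho(1 - u)) and the monomial rho(u) = u**3. The parameter values are arbitrary.

```python
import jax
import jax.numpy as jnp


def rho(u):
    # monomial choice (assumption); the exp variant is harder to invert
    return u**3


def ramp(u):
    # smooth ramp from 0 to 1 on [0, 1], constant outside
    u = jnp.clip(u, 0.0, 1.0)
    return rho(u) / (rho(u) + rho(1.0 - u))


def g(x, a, b):
    # assumed parametrisation: a sets the steepness, b the location
    return ramp(a * (x - b) + 0.5)


def f(x, a, b, c):
    # f(x) = (1 - c) (g(x) - g(0)) / (g(1) - g(0)) + c x
    num = g(x, a, b) - g(0.0, a, b)
    den = g(1.0, a, b) - g(0.0, a, b)
    return (1.0 - c) * num / den + c * x


a, b, c = 2.0, 0.4, 0.1  # arbitrary values, with 0 < c <= 1
xs = jnp.linspace(0.0, 1.0, 101)
ys = f(xs, a, b, c)
slopes = jax.vmap(jax.grad(f), in_axes=(0, None, None, None))(xs, a, b, c)

print("f(0) =", ys[0], " f(1) =", ys[-1], " min slope =", slopes.min())
```

Under these assumptions f maps [0, 1] onto itself and its slope never drops below c, which is what makes it invertible even when no closed-form inverse is available.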
hummmmmmmmmmmmm that sounds surprising
ok, maybe for the exp it is hard to find an analytical inverse ^^' the monomial should be easier, and otherwise we could implement a general-purpose inverse function, with gradients computed by the implicit function theorem.
;-) wink wink @b-remy
I've been looking at ffjord, and we can indeed observe that working with a Continuous Normalizing Flow, which makes smooth transformations, yields a smoother score function!
https://colab.research.google.com/drive/1nCs0UH8CfToW6L4ZNehzERBdIx84Eg6k?usp=sharing
Here I used maximum likelihood only, no score matching loss because I have not figured out how to implement it with tensorflow yet...
Maybe we should open a specific issue dedicated to ODE flows, to discuss different loss functions or how the gradients are actually computed. And maybe consider implementing a JAX version because taking gradients, or computing vjp, is not as easy in TF :-)
Yep @b-remy agreed, we can open a separate issue to discuss using an ODE flow for this :-) We can use this as a plan B, if plan A of using custom coupling layer doesn't work.
@Justinezgh do you have some news on building an invertible coupling? If not analytically possible, we can use an implicit function trick to define the gradients of a numerical inverse. @b-remy already has experience with this, it's a little bit more involved, but if we don't have analytic inverses it should work.
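A minimal sketch of the implicit function trick, assuming a scalar monotone map on [0, 1]. The map forward below is a toy stand-in (not the coupling from the paper), the inverse is found by bisection, and the derivative rule comes from differentiating f(x, theta) = y. This is not @b-remy's implementation.

```python
import jax
import jax.numpy as jnp


def forward(x, theta):
    # toy monotone map on [0, 1] (assumption): f(0)=0, f(1)=1, f' > 0 for 0 < theta <= 1
    return (1.0 - theta) * x**3 + theta * x


@jax.custom_jvp
def inverse(y, theta):
    # numerical inverse by bisection on [0, 1]; scalar inputs only
    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        below = forward(mid, theta) < y
        return jnp.where(below, mid, lo), jnp.where(below, hi, mid)

    lo, hi = jax.lax.fori_loop(
        0, 40, body, (jnp.zeros_like(y), jnp.ones_like(y)))
    return 0.5 * (lo + hi)


@inverse.defjvp
def inverse_jvp(primals, tangents):
    y, theta = primals
    dy, dtheta = tangents
    x = inverse(y, theta)
    df_dx, df_dtheta = jax.grad(forward, argnums=(0, 1))(x, theta)
    # implicit function theorem: f(x, theta) = y
    #   =>  dx = (dy - df/dtheta * dtheta) / (df/dx)
    dx = (dy - df_dtheta * dtheta) / df_dx
    return x, dx


y, theta = jnp.array(0.3), jnp.array(0.2)
x = inverse(y, theta)
dx_dy, dx_dtheta = jax.grad(inverse, argnums=(0, 1))(y, theta)
print("x =", x, " dx/dy =", dx_dy, " dx/dtheta =", dx_dtheta)
```

The gradients never differentiate through the bisection loop itself, so the solver can be swapped for any other root finder. Batches can be handled with jax.vmap over y.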
I think the best I can do is rho(x) = x**2 :/
https://colab.research.google.com/drive/1kRA4ReFryVqFJfLxwtL7nXg-Uwsn1sal?usp=sharing
( I can't specify the domain if a,b and c are symbols)
I was trying to compute f^-1 as a function of x,a,b,c in order to use it directly in the bijector but I didn't manage to do it and I don't think that sympy is jaxssifiable -> Ok I think I just managed to do it
Ok so the best I can do is rho(x) = x**3
\o/ x^3 should work for our purposes! and maybe x^2 is actually enough... we just need one more order of smoothness than the typical affine coupling.
Let's see what this gives us in practice in a bijector :-)
Really awesome that you used sympy for solving this BTW!
Any luck with implementing a bijector? Don't hesitate if you have questions ;-)
I have "some" bugs :D
https://colab.research.google.com/drive/1cmtlXbH-xX7s7m7MtiL4DWyD_UriSoIg?usp=sharing
When I try to train the NF I have this error (1024 is the batch size) : 'ValueError: The arguments to _cofactor_solve must have shapes a=[..., m, m] and b=[..., m, m]; got a=(1024, 1, 1024, 1) and b=(1024, 1, 1024, 1)'
So I tried with batch_size = 1 and I noticed that the loss is NaN. I tried the same thing with an easier bijector Exp() and I have the same pb for the loss. So I tried to print() everything in the NN to get a,b,c and for some reason the initialization part fails
so several things,
jnp.log(jnp.abs(jnp.linalg.det(jax.jacfwd(f, argnums = 0)(x,self.a, self.b, self.c))))
is there an analytic jacobian for this bijector ?
did it help ^^' ?
yup ! I computed the gradients with Sympy : https://colab.research.google.com/drive/1URrqY8TVf0EbtO2DHpjqEnR4jIvs2j-P?usp=sharing I don't know if it's faster to have the Jacobian for both f and f^-1 or to use the fact that forward_log_det_jacobian is the negative of inverse_log_det_jacobian, evaluated at f^{-1}(y).
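On the log-det question, a small sketch may help. It assumes the bijector acts elementwise, so the Jacobian is diagonal and the log determinant is a sum of log|f'(x_i)| with no need for jnp.linalg.det. The map used here (sinh, with inverse arcsinh) is a toy stand-in for the actual coupling function.

```python
import jax
import jax.numpy as jnp


def f(x):
    # toy elementwise bijector (assumption), stand-in for the coupling
    return jnp.sinh(x)


def f_inv(y):
    return jnp.arcsinh(y)


def forward_log_det_jacobian(x):
    # x has shape (d,); elementwise map => diagonal Jacobian
    derivs = jax.vmap(jax.grad(f))(x)
    return jnp.sum(jnp.log(jnp.abs(derivs)))


def inverse_log_det_jacobian(y):
    derivs = jax.vmap(jax.grad(f_inv))(y)
    return jnp.sum(jnp.log(jnp.abs(derivs)))


x = jax.random.normal(jax.random.PRNGKey(0), (1024, 2))

# vmap over the batch so each sample gets its own (d, d) Jacobian
fldj = jax.vmap(forward_log_det_jacobian)(x)
ildj = jax.vmap(inverse_log_det_jacobian)(f(x))

# reference: full per-sample Jacobian, shape (1024, 2, 2)
_, ref = jnp.linalg.slogdet(jax.vmap(jax.jacfwd(f))(x))

print(fldj.shape, float(jnp.max(jnp.abs(fldj + ildj))))
```

This checks that forward_log_det_jacobian(x) equals minus inverse_log_det_jacobian(f(x)), so only one of the two needs an analytic expression. Applying jax.jacfwd to the whole batch at once instead of per sample produces a Jacobian with two batch axes, which may be related to the shape reported in the _cofactor_solve error above.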
And so now I'm dealing with a new pb :D Just have to find a way to have x \in [-1/2a+b, 1/2a+b]
This looks good Justine, but I didn't quite get your point about x in a given range... To keep things simple for now, can we define a flow that remains between (0,1) ?
@Justinezgh here are some examples I have lying around of building a normalizing flow in jax, and training it on the two moons distribution:
[1] full implementation of a NF in JAX+flax but it is kind of outdated: https://github.com/EiffL/jax-nf (see this notebook in particular https://github.com/EiffL/jax-nf/blob/master/notebooks/Vanilla-NVP.ipynb)
[2] notebook with a NF implementation in JAX+haiku: https://github.com/EiffL/Quarks2CosmosDataChallenge/blob/main/notebooks/PartII-GenerativeModels-Solution.ipynb (see Step III: Latent Normalizing Flow and ignore all the rest, it just shows you how to build a NF with jax and haiku)
So I would say, you can try to rewrite a small notebook, using [1] as an example for how to generate examples from the two moons dataset, and [2] for an example of a slightly better implementation using haiku
Learning objectives: