astrodeepnet / sbi_experiments

Simulation Based Inference experiments

Learning the two-moons distribution with a normalizing flow #2

Open EiffL opened 2 years ago

EiffL commented 2 years ago

@Justinezgh here are some examples I have lying around of building a normalizing flow in JAX and training it on the two moons distribution:

So I would say, you can try to write a small notebook, using 1 as an example of how to generate samples from the two moons dataset, and 2 as an example of a slightly better implementation using haiku.

Learning objectives:

EiffL commented 2 years ago

Ok well, turns out I had a colab notebook with everything in one place: https://colab.research.google.com/drive/1HPom85QIjugHaL2RkO-5TWle6ZeoVBWC?usp=sharing

Can you see if it works for you? And if so, can you add your version of this to this repo?

EiffL commented 2 years ago

Ah, and instead of using the sklearn two moons dataset, you can use the pure TFP one from this notebook: https://colab.research.google.com/drive/1yRsh1Kmb6O1J6Rx3v1hX7-oS9cQUyGiM?usp=sharing

The advantage is that it will also allow you to compute gradients ;-)

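For context, a pure TFP two moons with a tractable, differentiable log_prob can be approximated as a Gaussian mixture placed along the two arcs. This is only a hedged sketch (the notebook's actual get_two_moons may be built differently):

import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions

def get_two_moons(sigma=0.05, n_components=128):
  # Gaussian components placed along the two half-moons (sketch, not the notebook's exact recipe)
  t = jnp.linspace(0.0, jnp.pi, n_components // 2)
  upper = jnp.stack([jnp.cos(t), jnp.sin(t)], axis=-1)
  lower = jnp.stack([1.0 - jnp.cos(t), 0.5 - jnp.sin(t)], axis=-1)
  locs = jnp.concatenate([upper, lower], axis=0)
  return tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(logits=jnp.zeros(locs.shape[0])),
      components_distribution=tfd.MultivariateNormalDiag(
          loc=locs, scale_diag=sigma * jnp.ones_like(locs)))

two_moons = get_two_moons()
batch = two_moons.sample(256, seed=jax.random.PRNGKey(0))
score = jax.vmap(jax.grad(two_moons.log_prob))(batch)  # the gradients mentioned above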

Justinezgh commented 2 years ago

Ok thanks a lot! I will look into all of this :)

Justinezgh commented 2 years ago

Learning the two moons from tensorflow using RealNVP https://colab.research.google.com/drive/1E2o54mt8KHlnWkwJCaEpzBunmTR3NmWC?usp=sharing

Justinezgh commented 2 years ago

Learning the two moons from tensorflow using RealNVP + using the score https://colab.research.google.com/drive/1t4DaL02o31OCOFifDaQS2B1f_QN5-iRq?usp=sharing
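For reference, "using the score" presumably means adding a score matching penalty on top of the usual negative log-likelihood. A minimal sketch under that assumption, where nf_log_prob(params, x) is an assumed name for the flow's log density and true_score comes from the TFP two moons:

def combined_loss(params, batch, true_score, lam=1.0):
  # NLL term plus a penalty on the mismatch between the flow's score and the true score
  nll = -jnp.mean(nf_log_prob(params, batch))
  flow_score = jax.vmap(jax.grad(lambda x: nf_log_prob(params, x)))(batch)
  sm = jnp.mean(jnp.sum((flow_score - true_score) ** 2, axis=-1))
  return nll + lam * sm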

Justinezgh commented 2 years ago

I can't use @jax.jit on the get_batch function (from this notebook: https://colab.research.google.com/drive/1t4DaL02o31OCOFifDaQS2B1f_QN5-iRq?usp=sharing ). When I use it I get this error: 'IndexError: tuple index out of range'

EiffL commented 2 years ago

Could you try the following?

@jax.jit
def get_batch(batch_size, seed):
  # build the distribution inside the jitted function and reuse it for the score
  two_moons = get_two_moons(sigma=0.05)
  batch = two_moons.sample(batch_size, seed=seed)
  score = jax.vmap(jax.grad(two_moons.log_prob))(batch)
  return batch, score

Maybe it's an issue coming from the fact that you build the distribution outside of the jitted function

EiffL commented 2 years ago

So, @Justinezgh, I think you are already pretty much all set up to start some fun research and experiments, so I want to show you some preliminary work on this stuff we did with @b-remy last year.

We were testing a technique called denoising score matching to learn the score field (not the distribution itself), and we did some tests against what a conventional Normalizing Flow could achieve. Here is a relevant plot (from this notebook: https://github.com/b-remy/score-estimation-comparison/blob/normalizing_flows/notebooks/NF-DAE-SN-comparison.ipynb).

It shows that when training a Normalizing Flow just for density estimation, the score field can go all wonky. Also, if you think about the change of variables formula in a Normalizing Flow, the score will have two terms: one that comes from the inverse mapping, and one that comes from the Jacobian determinant. For a RealNVP, @b-remy also made this plot (https://github.com/b-remy/score-estimation-comparison/blob/normalizing_flows/notebooks/NFlows_where_come_from_the_failures.ipynb), which shows that the determinant part seems to be responsible for most of the bad behavior; this probably implies that the particular shape of the RealNVP determinant is not very regular.
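For reference, the denoising score matching loss mentioned above fits a score network to the score of a Gaussian-smoothed version of the data. A minimal sketch, assuming some score network score_fn(params, x):

def dsm_loss(params, batch, key, sigma=0.1):
  # perturb the data and regress the network output onto the score of the noise kernel
  noise = sigma * jax.random.normal(key, batch.shape)
  target = -noise / sigma**2
  pred = score_fn(params, batch + noise)
  return jnp.mean(jnp.sum((pred - target) ** 2, axis=-1))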

EiffL commented 2 years ago

This makes me think that a first angle of attack could be to check that, for a given choice of normalizing flow architecture, the log density is indeed continuously differentiable. And thinking about the log determinant term is probably a good idea.

You can also have a look at one of the seminal papers on score matching: https://www.cs.helsinki.fi/u/ahyvarin/papers/JMLR05.pdf

Justinezgh commented 2 years ago

Impact of the number of (affine) coupling layers on the score field: https://colab.research.google.com/drive/1H0Q_hgb0Yjtqvyg9RKeqTvt5lSZBNiap?usp=sharing
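A minimal sketch of how such a score-field plot can be produced, assuming a trained flow density nf_log_prob(params, x) (name and grid ranges are placeholders):

xx, yy = jnp.meshgrid(jnp.linspace(-1.5, 2.5, 30), jnp.linspace(-1.0, 1.5, 30))
grid = jnp.stack([xx.ravel(), yy.ravel()], axis=-1)
score_field = jax.vmap(jax.grad(lambda x: nf_log_prob(params, x)))(grid)
# then e.g. plt.quiver(grid[:, 0], grid[:, 1], score_field[:, 0], score_field[:, 1])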

Justinezgh commented 2 years ago

Same but with Neural Spline Flows: https://colab.research.google.com/drive/1IFDmsNUTsHIjQpjnXKIAG3PUeyx6NLux?usp=sharing

And I still have problems with @jax.jit

EiffL commented 2 years ago
batch_size = 512  # a module-level constant now, so jit never traces it

@jax.jit
def get_batch(seed):
  two_moons = get_two_moons(sigma=0.05)
  batch = two_moons.sample(batch_size, seed=seed)
  return batch

Simple fix for jitting get_batch: remove the batch_size argument so that only the PRNG seed is traced.

EiffL commented 2 years ago

@Justinezgh this is all super interesting. Two questions:

b-remy commented 2 years ago

Hi @Justinezgh, note that if it is more convenient, you can also keep the batch_size argument by specifying to jit that it is a static argument.

import functools
import jax

# batch_size (positional argument index 1) is marked static, so jit recompiles per value
@functools.partial(jax.jit, static_argnums=(1,))
def get_batch(seed, batch_size):
  two_moons = get_two_moons(sigma=0.05)
  batch = two_moons.sample(batch_size, seed=seed)
  return batch
EiffL commented 2 years ago

Sorry, I was too curious... I quickly tried to train a regression network under a score matching loss to make sure things were not crazy. And it seems to work pretty well.

The loss function nicely goes to zero instead of jumping around as in the NF examples.
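A minimal sketch of that regression setup, assuming Haiku and the (batch, score) pairs from get_batch above (network shape and sizes are placeholders):

import haiku as hk

def score_net(x):
  return hk.nets.MLP([128, 128, 2], activation=jax.nn.relu)(x)

net = hk.without_apply_rng(hk.transform(score_net))
params = net.init(jax.random.PRNGKey(1), batch)

def loss_fn(params, batch, score):
  # plain L2 regression of the network output onto the true score
  pred = net.apply(params, batch)
  return jnp.mean(jnp.sum((pred - score) ** 2, axis=-1))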

Training the grads of the NN

For fun I also tried to train the same network, but making it output just a scalar and constraining its gradients, and that doesn't train at all.


Training the grads of the NN with a C\infty neural network

And for even more fun, I tried training the same model again by constraining the grads, replacing relu activation by a sin function, as proposed in https://arxiv.org/abs/2006.09661

And BAM! By magic it works \o/ (note: for the background on the right I use exp(scalar output of the network))

And training goes super easily.

=> All codes available here: https://github.com/astrodeepnet/sbi_experiments/pull/4
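A minimal sketch of the "train the gradients" variant with sin activations (a SIREN-style network); names are placeholders and the actual code is in the PR above:

def scalar_net(x):
  # scalar "log-density-like" output, smooth thanks to sin activations
  return hk.nets.MLP([128, 128, 1], activation=jnp.sin)(x)[..., 0]

scalar = hk.without_apply_rng(hk.transform(scalar_net))
params = scalar.init(jax.random.PRNGKey(2), batch)

def grad_loss_fn(params, batch, score):
  # fit the gradient of the scalar output to the target score
  pred_score = jax.vmap(jax.grad(lambda x: scalar.apply(params, x)))(batch)
  return jnp.mean(jnp.sum((pred_score - score) ** 2, axis=-1))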

Justinezgh commented 2 years ago

https://colab.research.google.com/drive/1OnL56FPKzinJrnL16xFdYXcBiKSPsOmy?usp=sharing :)

EiffL commented 2 years ago

Ahaha yep ^^ sorry, this had been bugging me all afternoon and I was dying to try it, it's pretty fun stuff :-)

EiffL commented 2 years ago

So, the next logical step is to build a NF that is by construction Cinfty.

@b-remy reminded me of this paper: https://arxiv.org/pdf/2110.00351.pdf where they actually propose a coupling layer that should be continuously differentiable, to place in a RealNVP. Probably worthwhile to take a look.

Justinezgh commented 2 years ago

So just to see, I tried to use the sin activation function for the NN of the affine coupling layer.


Same for the NF with 3 coupling layers.


Notebook: https://colab.research.google.com/drive/1ZU-w76vJ81-PArB9vr1x9fi7qpZ1AOnu?usp=sharing

EiffL commented 2 years ago

interesting interesting yeah, it doesn't seem to help directly :-/

So here is what they say in section 4 of (2110.00351): with an affine coupling layer, you lose expressivity in the gradients of log p, and Neural Spline Flows have poor gradients because they are only C1.

So I think we could try the following: use a C\infty coupling layer and train under the score matching loss (whether or not the flow can train under the SM loss on its own will tell us if the model is well adapted).

Alternatively, it might be possible to use a MAF instead of a RealNVP, because if the masked autoencoder in the MAF layer is Cinfty, the flow might be too.

EiffL commented 2 years ago

(just for the record, what I said there about MAF was stupid, you still have an Affine Coupling with a MAF)

Justinezgh commented 2 years ago

Just to be sure, is this the function that we want to place in a RealNVP?

[screenshot of the definition of f from the paper]

If yes, do we want to use f to define the shift and the scale part? Because both shift and scale are R^d -> R^(D-d), so do we have to do some kind of projection for f(x)? Like we define f(x_i) := (1-c).((g(x_i)-g(0))/(...)) + c.x_i, so f(x) \in R^d, and then we project into R^(D-d)?

Actually I'm not sure that f was made to be used in a RealNVP, idk..

EiffL commented 2 years ago

That's a good question. And yes that's the coupling we might want to use, instead of an affine coupling.

So you don't generate shift and scale parameters; instead you generate these a, b, c parameters, which are the outputs of some neural network that takes R^d inputs and returns R^(D-d) outputs, and the function g is a bijection in R^(D-d).

You can have a look at how the Spline flows work; it's a bit different, but it illustrates how a parametrisation can deviate from affine.

EiffL commented 2 years ago

It may not be 100% trivial because I think you would have to define a TFP bijector to implement the mapping f. It shouldn't be too difficult, but will take a bit of coding.
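A rough skeleton of what such a bijector could look like with the JAX substrate of TFP; f, f_inverse and df_dx below are placeholders (e.g. closed-form expressions exported from sympy), not the paper's code:

import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
tfb = tfp.bijectors

class SmoothCoupling(tfb.Bijector):
  def __init__(self, a, b, c, name="smooth_coupling"):
    super().__init__(forward_min_event_ndims=0, name=name)
    self.a, self.b, self.c = a, b, c

  def _forward(self, x):
    return f(x, self.a, self.b, self.c)           # placeholder for the mapping f

  def _inverse(self, y):
    return f_inverse(y, self.a, self.b, self.c)   # placeholder for its inverse

  def _forward_log_det_jacobian(self, x):
    # elementwise bijector: log |df/dx|
    return jnp.log(jnp.abs(df_dx(x, self.a, self.b, self.c)))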

Ah, and there is another approach we could take, I think... We could use an FFJORD, and there is one easily usable in the TF version of TFP (so not in JAX, unfortunately). I think if the ODE function is sufficiently smooth, so is the ODE flow.

b-remy commented 2 years ago

+1 I was also thinking that Continuous Normalizing Flows (the flow of transformations being continuous here) such as Neural ODE (1806.07366) or FFJORD (1810.01367) would be an interesting approach to look at in parallel!

Justinezgh commented 2 years ago

I'm not sure that the function f(x) = (1-c)((g(x)-g(0))/(g(1)-g(0))) + cx has an analytical inverse. At least for rho(x) = exp(-1/alpha*x**beta).

EiffL commented 2 years ago

hummmmmmmmmmmmm that sounds surprising

EiffL commented 2 years ago

Ok, maybe the exp is hard to invert analytically ^^' The monomial should be easier, and otherwise we could implement a general-purpose inverse function, with gradients computed by the implicit function theorem.
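A minimal sketch of that general-purpose route in JAX: invert a monotone map numerically (bisection on [0, 1] here, with a toy f as an assumption) and define the derivative of the inverse through the implicit function theorem, so it stays differentiable:

import jax
import jax.numpy as jnp

c = 0.5

def f(x):
  return (1 - c) * x**3 + c * x   # toy smooth, strictly increasing map on [0, 1]

@jax.custom_jvp
def f_inverse(y):
  # 60 bisection steps on [0, 1], enough for single precision
  def body(_, bounds):
    lo, hi = bounds
    mid = 0.5 * (lo + hi)
    go_right = f(mid) < y
    return jnp.where(go_right, mid, lo), jnp.where(go_right, hi, mid)
  lo, hi = jax.lax.fori_loop(0, 60, body, (jnp.zeros_like(y), jnp.ones_like(y)))
  return 0.5 * (lo + hi)

@f_inverse.defjvp
def f_inverse_jvp(primals, tangents):
  (y,), (dy,) = primals, tangents
  x = f_inverse(y)
  return x, dy / (3 * (1 - c) * x**2 + c)   # (f^-1)'(y) = 1 / f'(f^-1(y))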

EiffL commented 2 years ago

;-) wink wink @b-remy

b-remy commented 2 years ago

I've been looking at ffjord, and we can indeed observe that working with a Continuous Normalizing Flow, which makes smooth transformations, yields a smoother score function!

https://colab.research.google.com/drive/1nCs0UH8CfToW6L4ZNehzERBdIx84Eg6k?usp=sharing

Here I used maximum likelihood only, no score matching loss because I have not figured out how to implement it with tensorflow yet...

Maybe we should open a specific issue dedicated to ODE flows, to discuss different loss functions or how the gradients are actually computed. And maybe consider implementing a JAX version because taking gradients, or computing vjp, is not as easy in TF :-)

EiffL commented 2 years ago

Yep @b-remy agreed, we can open a separate issue to discuss using an ODE flow for this :-) We can use it as plan B, if plan A of using a custom coupling layer doesn't work.

@Justinezgh do you have some news on building an invertible coupling? If it's not analytically possible, we can use an implicit function trick to define the gradients of a numerical inverse. @b-remy already has experience with this; it's a little bit more involved, but if we don't have analytic inverses it should work.

Justinezgh commented 2 years ago

I think the best I can do is rho(x) = x**2 :/

https://colab.research.google.com/drive/1kRA4ReFryVqFJfLxwtL7nXg-Uwsn1sal?usp=sharing

(I can't specify the domain if a, b and c are symbols)

I was trying to compute f^-1 as a function of x, a, b, c in order to use it directly in the bijector, but I didn't manage to do it and I don't think the sympy output can be turned into JAX code -> Ok, I think I just managed to do it.

Ok so the best I can do is rho(x) = x**3
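For reference, a hedged sketch of the sympy route for the monomial case, treating g directly as x**3 for illustration (the notebook may define g from rho differently):

import sympy as sp

x, y, c = sp.symbols("x y c", real=True)
g = lambda t: t**3
f = (1 - c) * (g(x) - g(0)) / (g(1) - g(0)) + c * x
# Solve y = f(x) for x: a cubic in x, so sympy returns the three Cardano roots;
# the real root on the domain of interest can then be transcribed into JAX by hand.
inverse_candidates = sp.solve(sp.Eq(y, f), x)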

EiffL commented 2 years ago

\o/ x^3 should work for our purposes! And maybe x^2 is actually enough... we just need one more order of smoothness than the typical affine coupling.

Let's see what this gives us in practice in a bijector :-)

EiffL commented 2 years ago

Really awesome that you used sympy for solving this BTW!

EiffL commented 2 years ago

Any luck with implementing a bijector? Don't hesitate if you have questions ;-)

Justinezgh commented 2 years ago

I have "some" bugs :D

https://colab.research.google.com/drive/1cmtlXbH-xX7s7m7MtiL4DWyD_UriSoIg?usp=sharing

When I try to train the NF I get this error (1024 is the batch size): 'ValueError: The arguments to _cofactor_solve must have shapes a=[..., m, m] and b=[..., m, m]; got a=(1024, 1, 1024, 1) and b=(1024, 1, 1024, 1)'

So I tried with batch_size = 1 and I noticed that the loss is NaN. I tried the same thing with an easier bijector, Exp(), and I have the same problem with the loss. So I tried to print() everything in the NN to get a, b, c, and for some reason the initialization part fails.

EiffL commented 2 years ago

So, several things:

EiffL commented 2 years ago

did it help ^^' ?

Justinezgh commented 2 years ago

Yup! I computed the gradients with sympy: https://colab.research.google.com/drive/1URrqY8TVf0EbtO2DHpjqEnR4jIvs2j-P?usp=sharing

I don't know if it's faster to have the Jacobian for both f and f^-1, or to use the fact that forward_log_det_jacobian is the negative of inverse_log_det_jacobian evaluated at f^{-1}(y).

And so now I'm dealing with a new problem :D I just have to find a way to have x \in [-1/2a + b, 1/2a + b].

EiffL commented 2 years ago

This looks good Justine, but I didn't quite get your point about x in a given range... To keep things simple for now, can we define a flow that stays within (0, 1)?