astrodeepnet / sbi_experiments

Simulation Based Inference experiments
MIT License

Adds implicit bijector #10

Closed · EiffL closed this 2 years ago

EiffL commented 2 years ago

This PR adds an implementation of an implicit bijector to be used for defining smooth coupling layers. It is based on the example developed by @b-remy in #7.

It's only a draft PR because I have some dimension issues: we need the implicit layer to work for arrays, not just scalars. The bijector itself works fine with arrays, but it fails when trying to get the gradients for training a normalizing flow. See here for a demo using this bijector in an NF.

I couldn't figure out these dimension issues in a reasonable time, so maybe @b-remy you'll have more luck than me, whenever you have some time.

Merging this PR will close #7.
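(For context, the core idea: the coupling's activation is a monotone ramp whose inverse has no closed form, so the inverse is obtained numerically. Below is a toy sketch of that implicit-inverse pattern, not the PR's actual code; the cubic f and the Newton loop are stand-ins.)

import jax
import jax.numpy as jnp

# Stand-in for the monotone ramp on (0, 1); the real bijector builds
# a smoother ramp from parameters a, b, c.
f = lambda x: x ** 3

def newton_inverse(y, x0=0.5, n_steps=20):
    # Solve f(x) = y for a single scalar by Newton iteration.
    df = jax.grad(f)
    def step(x, _):
        return x - (f(x) - y) / df(x), None
    x, _ = jax.lax.scan(step, x0, None, length=n_steps)
    return x

# Batched inverse: one root per batch element.
ys = jnp.linspace(0.1, 0.9, 5)
xs = jax.vmap(newton_inverse)(ys)
print(jnp.allclose(f(xs), ys, atol=1e-5))  # True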

EiffL commented 2 years ago

It looks like the code works \o/ Thanks @b-remy! However, it looks like the normalizing flow dies at some point; it is probably related to the issues @Justinezgh is finding.

Justinezgh commented 2 years ago

I think there is a problem in the bijector when you compute the determinant: [screenshot]

It should be one value per element in your batch, right?
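(A quick check of that expectation, with a toy elementwise map standing in for the actual bijector:)

import jax
import jax.numpy as jnp

f = lambda x: x ** 3            # toy elementwise map on a d-vector

def logdet(x):                  # x: [d]
    jac = jax.jacobian(f)(x)    # [d, d]
    _, ld = jnp.linalg.slogdet(jac)
    return ld                   # scalar

batch = 0.5 * jnp.ones((8, 2))
print(jax.vmap(logdet)(batch).shape)  # (8,): one log-det per batch element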

Justinezgh commented 2 years ago

It seems that `_forward_log_det_jacobian` works correctly; the return is not a scalar. So I think the problem comes from `_inverse_log_det_jacobian`.

EiffL commented 2 years ago

Ah... Hum... We can manually define the inverse log det then: it would just call the inverse function, then the forward log det, and put a minus sign in front.

Justinezgh commented 2 years ago
def _inverse_log_det_jacobian(self, y):
  def logdet_fn(x, a, b, c):
    # Jacobian of the numerical inverse at a single point
    x = jnp.atleast_1d(x)
    jac = jax.jacobian(self.inv_f, argnums=0)(x, a, b, c)
    s, logdet = jnp.linalg.slogdet(jac)
    return s * logdet
  # one log-det per batch element
  return jax.vmap(logdet_fn)(y, self.a, self.b, self.c)
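(Side note: `jnp.linalg.slogdet` already returns log|det| as its second output; `s * logdet` coincides with it as long as the ramp is monotonically increasing, since the sign `s` is then always 1.)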

It works, but `chain.inverse_log_det_jacobian(x, event_ndims=1, event_shape=(2,))` still returns a scalar (even if I don't specify `event_ndims` and `event_shape`).
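(For comparison, here is the expected shape semantics with a stock TFP bijector; this is just a sanity check, not the PR's code. With `event_ndims=1`, the log-det should come back with one value per batch element.)

import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
tfb = tfp.bijectors

# Elementwise bijector: event_ndims=1 sums the per-feature log-dets
# over the last axis, leaving one value per batch element.
b = tfb.Scale(2.)
y = jnp.ones((8, 2))
print(b.inverse_log_det_jacobian(y, event_ndims=1).shape)  # (8,)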

EiffL commented 2 years ago

so this seems to work:

  def _inverse_log_det_jacobian(self, y):
    # reuse the forward log-det at the inverse point, with a sign flip
    x = self._inverse(y)
    return -self._forward_log_det_jacobian(x)
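(This is just the change-of-variables identity log|det J_{f^-1}(y)| = -log|det J_f(f^-1(y))|, so the two log-dets are consistent by construction.)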


EiffL commented 2 years ago

If this works, can you commit the modifications to this branch?

Justinezgh commented 2 years ago

Yup, it works, but when I try this method on this chain:

chain = tfb.Chain([
    tfb.RealNVP(d // 2, bijector_fn=CustomCoupling(name='b1')),
    tfb.Permute([1, 0]),
    tfb.RealNVP(d // 2, bijector_fn=CustomCoupling(name='b2')),
    tfb.Permute([1, 0]),
])

with `CustomCoupling`:

class CustomCoupling(hk.Module):
  def __call__(self, x, output_units, **condition_kwargs):
    # NN that outputs the ramp parameters a, b and c
    net = hk.Linear(128)(x)
    net = jax.nn.leaky_relu(net)
    net = hk.Linear(128)(net)
    net = jax.nn.leaky_relu(net)
    a = jax.nn.softplus(hk.Linear(output_units)(net))
    b = jax.nn.sigmoid(hk.Linear(output_units)(net))
    c = jax.nn.sigmoid(hk.Linear(output_units)(net))

    return ImplicitRampBijector(lambda x: x**3, a, b, c)

I get just one value

EiffL commented 2 years ago

Ah, I see. It's probably because in the chain x/y ends up being a scalar, since we only have a 2d vector. Our code assumes that the input of the bijector has shape [batch, d], but inside the chain it probably receives [batch] instead of [batch, 1].
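(A minimal guard along those lines, purely hypothetical and not necessarily what ended up in the branch: restore the trailing feature axis before doing the per-sample log-det math.)

import jax.numpy as jnp

def ensure_trailing_dim(x):
    # RealNVP can hand the coupling a [batch] array when only one
    # dimension is masked; promote it back to [batch, 1].
    x = jnp.asarray(x)
    return x[..., None] if x.ndim == 1 else x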

Justinezgh commented 2 years ago

OK, so I just reshaped the output of `_forward_log_det_jacobian`. It seems to work, but now I have new bugs :D

EiffL commented 2 years ago

And solved! The problem was a stupid shape issue at the output of the logdet (notebook here).

@Justinezgh can you review this PR :-) ?

Justinezgh commented 2 years ago

I noticed this problem for the numerical inverse: [screenshot]. Does what you changed solve it?

-> f^-1 is defined on [0.2, 0.9], not on [0, 1]

EiffL commented 2 years ago

Hmmm, I'm not sure what you mean... In your plot above the function looks invertible everywhere on (0, 1).

At the end of the notebook on this PR I have a plot where I check that the bijector gets the right inverse. But of course there might be cases I haven't tested

Justinezgh commented 2 years ago

Yes, I agree, but for some reason on [0, 0.2[ ∪ ]0.9, 1], f^-1 returns NaN (for a = 0.7, b = c = 0.5). But the problem must come from my bijector, because I tried with yours, resized your two moons from [0.2, 0.8] to [0.1, 0.9] (on x), and there was no problem.

EiffL commented 2 years ago

Great :-D if everything seems to work ok, can you review the PR? (doc on how to do that here ;-) ) and then we can merge it :-)

Justinezgh commented 2 years ago

> Great :-D if everything seems to work ok, can you review the PR? (doc on how to do that here ;-) ) and then we can merge it :-)

Thanks for the link ^^ And there is a new shape problem when I try to get the NF's score using vmap(); it seems to come from the Newton solver:

def newton_solver(f, z_init):
  f_root = lambda z: f(z)
  # PROBLEM: this solve fails under vmap
  g = lambda z: z - jnp.linalg.solve(jax.jacobian(f_root)(z), f_root(z))
  return fwd_solver(g, z_init)

ValueError: The arguments to solve must have shapes a=[..., m, m] and b=[..., m, k] or b=[..., m]; got a=() and b=()

I have everything here : https://colab.research.google.com/drive/1diL4HHor7k8tIqdv5PCM-j93Hmr_8eEF?usp=sharing
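(The error says `jax.jacobian(f_root)(z)` and `f_root(z)` are both 0-d, which happens when `z` arrives as a scalar under vmap. A possible guard, sketched with a naive fixed-point loop standing in for the repo's actual `fwd_solver`:)

import jax
import jax.numpy as jnp

def fwd_solver(g, z_init, n_steps=50):
    # naive fixed-point iteration, a stand-in for the repo's solver
    z = z_init
    for _ in range(n_steps):
        z = g(z)
    return z

def newton_solver(f, z_init):
    def g(z):
        z = jnp.atleast_1d(z)                     # scalar z -> shape (1,)
        jac = jnp.atleast_2d(jax.jacobian(f)(z))  # guarantee an [m, m] matrix
        return z - jnp.linalg.solve(jac, jnp.atleast_1d(f(z)))
    return fwd_solver(g, z_init)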

EiffL commented 2 years ago

Here is how I get the score of the flow with the code that is currently in this branch:

def loss_fn(params, batch, score):
  # per-sample log-prob and its gradient w.r.t. the input (the score)
  log_prob, out = jax.vmap(
      jax.value_and_grad(lambda x: model_NF.apply(params, x.reshape([1, 2])).squeeze())
  )(batch)
  return jnp.mean(jnp.sum((out - score)**2, axis=1))
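(The reshape to [1, 2] followed by `.squeeze()` makes the model's output a scalar per sample, which is what `jax.value_and_grad` needs; the gradient `out` is then the score ∇_x log p(x) that the loss compares against the target.)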

(I think ^^' maybe there is another fix I forgot to push)

EiffL commented 2 years ago

I've added the changes you requested @Justinezgh :-) There is now a test that will check we can indeed compute the score of a tiny NF that uses the ramp bijector. See here: https://github.com/astrodeepnet/sbi_experiments/blob/3ab376305b9c84c67625df3357b42eddddca0fa0/tests/test_bijectors.py#L52 Can you do another round of review of this PR? And approve it if you find that it's good enough for now?

EiffL commented 2 years ago

So what do you think?

Justinezgh commented 2 years ago

yup it works :)