It looks like the code works \o/ Thanks @b-remy! However, it looks like the normalizing flow dies at some point; it's probably related to the issues @Justinezgh is finding.
I think there is a problem in the bijector when you compute the determinant: it should be one value per element in your batch, right?
It seems that `_forward_log_det_jacobian` works correctly, the return is not a scalar. So I think the problem comes from `_inverse_log_det_jacobian`.
Ah... Hum... We can manually define the inverse log det then, which will just call the inverse function, then the forward log det, and put a minus sign in front:
```python
def _inverse_log_det_jacobian(self, y):
    def logdet_fn(x, a, b, c):
        x = jnp.atleast_1d(x)
        jac = jax.jacobian(self.inv_f, argnums=0)(x, a, b, c)
        s, logdet = jnp.linalg.slogdet(jac)
        return s * logdet
    return jax.vmap(logdet_fn)(y, self.a, self.b, self.c)
```
It works, but this: `chain.inverse_log_det_jacobian(x, event_ndims=1, event_shape=(2,))` still returns a scalar (even if I don't specify `event_ndims` and `event_shape`).
so this seems to work:
```python
def _inverse_log_det_jacobian(self, y):
    x = self._inverse(y)
    return -self._forward_log_det_jacobian(x)
```
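For reference, a quick sanity check could look like this (a hypothetical snippet, with `bij` standing for an `ImplicitRampBijector` instance):

```python
# Hypothetical check: the inverse log-det should give one value per batch
# element and equal minus the forward log-det at the corresponding point.
x = jnp.ones((4, 2)) * 0.5
y = bij.forward(x)
ildj = bij.inverse_log_det_jacobian(y, event_ndims=1)
fldj = bij.forward_log_det_jacobian(x, event_ndims=1)
print(ildj.shape)                 # expect (4,), not a scalar
print(jnp.allclose(ildj, -fldj))  # expect True
```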
If this works, can you commit the changes to this branch?
Yup, it works, but when I try this method on this chain:
```python
chain = tfb.Chain([
    tfb.RealNVP(d//2, bijector_fn=CustomCoupling(name='b1')),
    tfb.Permute([1, 0]),
    tfb.RealNVP(d//2, bijector_fn=CustomCoupling(name='b2')),
    tfb.Permute([1, 0]),
])
```
with `CustomCoupling`:
```python
class CustomCoupling(hk.Module):
    def __call__(self, x, output_units, **condition_kwargs):
        # NN to get a, b and c
        net = hk.Linear(128)(x)
        net = jax.nn.leaky_relu(net)
        net = hk.Linear(128)(net)
        net = jax.nn.leaky_relu(net)
        a = jax.nn.softplus(hk.Linear(output_units)(net))  # a > 0
        b = jax.nn.sigmoid(hk.Linear(output_units)(net))   # b in (0, 1)
        c = jax.nn.sigmoid(hk.Linear(output_units)(net))   # c in (0, 1)
        return ImplicitRampBijector(lambda x: x**3, a, b, c)
```
I get just one value.
Ah, I see. It's probably because inside the chain x/y ends up being a scalar, since we only have a 2d vector: our code assumes that the input of the bijector has shape `[batch, d]`, but inside the chain it probably receives `[batch]` instead of `[batch, 1]`.
Ok, so I just reshaped the output of `_forward_log_det_jacobian` (roughly as in the sketch below); it seems to work, but now I have new bugs :D
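Roughly, that reshape amounts to something like this sketch (`self.f` standing for the forward map is an assumption; the actual change is in the branch):

```python
def _forward_log_det_jacobian(self, x):
    def logdet_fn(x, a, b, c):
        x = jnp.atleast_1d(x)
        jac = jax.jacobian(self.f, argnums=0)(x, a, b, c)
        s, logdet = jnp.linalg.slogdet(jac)
        return s * logdet
    # Reshape the output so there is one log-det value per batch element,
    # even when the chain feeds in shape [batch] instead of [batch, 1].
    return jax.vmap(logdet_fn)(x, self.a, self.b, self.c).reshape(-1)
```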
And solved! The problem was a stupid shape issue at the output of the logdet (notebook here)
@Justinezgh can you review this PR :-) ?
I noticed this problem with the numerical inverse; does what you changed solve it?
-> `f^-1` is defined on [0.2, 0.9], not on [0, 1]
Hummmm I'm not sure what you mean... In your plot above the function looks invertible everywhere on (0,1)
At the end of the notebook on this PR I have a plot where I check that the bijector gets the right inverse. But of course there might be cases I haven't tested
Yes, I agree, but for some reason on [0, 0.2[ ∪ ]0.9, 1], `f^-1` returns NaN (for a = 0.7, b = c = 0.5). But the problem must come from my bijector, because I tried with yours: I resized your two moons from [0.2, 0.8] to [0.1, 0.9] (on x) and there is no problem.
Great :-D if everything seems to work ok, can you review the PR? (doc on how to do that here ;-) ) and then we can merge it :-)
Thanks for the link ^^ And there is a new shape problem when I try to get the NF's score using `vmap()`; it seems to come from the Newton solver:
```python
def newton_solver(f, z_init):
    f_root = lambda z: f(z)
    g = lambda z: z - jnp.linalg.solve(jax.jacobian(f_root)(z), f_root(z))  # PROBLEM
    return fwd_solver(g, z_init)
```

```
ValueError: The arguments to solve must have shapes a=[..., m, m] and b=[..., m, k] or b=[..., m]; got a=() and b=()
```
I have everything here : https://colab.research.google.com/drive/1diL4HHor7k8tIqdv5PCM-j93Hmr_8eEF?usp=sharing
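One way the shapes could be fixed, sketched only (assuming `f` returns an array of the same shape as its input, with `fwd_solver` as in the snippet above; the committed code may differ):

```python
def newton_solver(f, z_init):
    # Promote z to at least 1D so jax.jacobian yields a [d, d] matrix and
    # jnp.linalg.solve gets the shapes it expects, even for the scalar
    # inputs that vmap produces.
    f_root = lambda z: f(jnp.atleast_1d(z))
    def g(z):
        z = jnp.atleast_1d(z)
        return z - jnp.linalg.solve(jax.jacobian(f_root)(z), f_root(z))
    return fwd_solver(g, jnp.atleast_1d(z_init))
```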
Here is how I get the score of the flow with the code that is currently in this branch:
```python
def loss_fn(params, batch, score):
    # Each sample is reshaped to [1, 2] so the flow sees a batch of one;
    # squeeze() makes the log-prob a scalar so value_and_grad can return
    # the score d log p / dx alongside it.
    log_prob, out = jax.vmap(jax.value_and_grad(
        lambda x: model_NF.apply(params, x.reshape([1, 2])).squeeze()))(batch)
    return jnp.mean(jnp.sum((out - score)**2, axis=1))
```
(I think ^^' maybe there is another fix I forgot to push)
I've added the changes you requested @Justinezgh :-) There is now a test that will check we can indeed compute the score of a tiny NF that uses the ramp bijector. See here: https://github.com/astrodeepnet/sbi_experiments/blob/3ab376305b9c84c67625df3357b42eddddca0fa0/tests/test_bijectors.py#L52 Can you do another round of review of this PR? And approve it if you find that it's good enough for now?
So, what do you think?
yup it works :)
This PR adds an implementation of an implicit bijector to be used for defining smooth coupling layers. It is based on the example developed by @b-remy in #7
It's only a draft PR because I have some dimension issues: we need the implicit layer to work for arrays, not just scalars. The bijector works fine with arrays, but it fails when trying to get the gradients for training a normalizing flow. See here for a demo using this bijector in an NF.
I couldn't figure out these dimension issues in a reasonable time, so maybe @b-remy you'll have more luck than me, whenever you have some time.
Merging this PR will close #7.