EiffL opened this issue 2 years ago
thanks for tagging me! I think I could easily incorporate this simulator into my IMNN compression framework to see how SBI shakes out... would be great to have the k -> ell conversion eventually...
Yep, I have played very quickly with training a simple compressor network, but training an IMNN would be nice. The main focus of this project is how to efficiently do the inference part, starting from "a" compressed statistic.
After additional tinkering, it looks like it kind of works :-) learning the posterior distribution with a combination of NLL and score loss. (notebook here)
So for instance, here is the true posterior distribution:
Here is what you can estimate when using only 64 simulations and an NLL loss:
and here is what you get if you use 64 sims and a combination of NLL and score loss:
Not yet perfect (it's only 64 sims) but looks like a huge improvement already.
So I think this example should be enough for our purposes on a timescale of ~a week.
Did you manage to run the notebook @Justinezgh ? Did you run into any problems?
Yes :) But I don't really understand why you use jax.lax.stop_gradient in the eval_model function when you compute log p(theta | x, z).
It's because you are evaluating the model at the position of the point x that has just been sampled, but as if it were a "data point" coming from outside the model, i.e. you don't want the gradients to propagate through x, which should remain fixed.
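To illustrate the idea, here is a minimal, self-contained sketch with a toy Gaussian model (not the notebook's actual eval_model):

```python
import jax
import jax.numpy as jnp

def log_lik(theta, key):
    x = theta + jax.random.normal(key)   # x ~ N(theta, 1), generated inside the model
    x = jax.lax.stop_gradient(x)         # freeze x: treat it as fixed "observed" data
    return -0.5 * (x - theta) ** 2       # log p(x | theta) up to a constant

g = jax.grad(log_lik)(0.5, jax.random.PRNGKey(0))
# Without stop_gradient the gradient would be identically zero, because
# x = theta + eps cancels theta in (x - theta); freezing x leaves the usual
# score (x - theta) evaluated at the sampled point.
```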
Since p(theta|x) = \int p(theta | x,z)p(z)dz, can I do this to get a 'real' approximation of the posterior :
```python
import jax
import jax.numpy as jnp
import numpy as np
from numpyro.handlers import condition, seed, trace

def get_jointlog_posterior(params, key):
    omega_c, s8 = params
    # generate a mock data map at the fiducial parameters
    cond_model = condition(model, {'omega_c': 0.27, 'sigma_8': 0.77})
    cond_model = seed(cond_model, key)
    m_data = cond_model()
    # condition on (theta, x=m_data) and sum the log-probs of all sites -> log p(x, z, theta)
    cond_model = condition(model, {'omega_c': omega_c, 'sigma_8': s8, 'x': m_data})
    cond_model = seed(cond_model, key)
    model_trace = trace(cond_model).get_trace()
    jointlog = jnp.array([model_trace[k]['fn'].log_prob(model_trace[k]['value']) for k in model_trace.keys()]).sum()
    return jointlog

get_log_prob = jax.jit(jax.vmap(get_jointlog_posterior, in_axes=[0, None]))

# evaluate on a 64x64 grid of (omega_c, sigma_8), averaged over 50 random keys
p = jnp.stack(jnp.meshgrid(jnp.linspace(0.2, 0.4, 64),
                           jnp.linspace(0.7, 0.9, 64)), axis=-1)
true_posterior = jnp.array([np.stack([get_log_prob(params, jax.random.PRNGKey(k)) for params in p], axis=0) for k in range(50)]).mean(axis=0)
```
Just to get an idea of what I'm supposed to have
hmm I think this ^^^ is p(x, theta), so I have p(theta|x) * const?
yeah it looks ok, but remember that it's an average over the probabilities, not the log probabilities ;-)
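To make that concrete, the averaging over keys has to be done in probability space; a small self-contained helper (written for this thread, not taken from the notebook) does it stably in log space:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def log_mean_prob(log_probs, axis=0):
    # log of the average of the probabilities, computed stably:
    # log( (1/N) * sum_i exp(log_probs_i) ) = logsumexp(log_probs) - log(N)
    n = log_probs.shape[axis]
    return logsumexp(log_probs, axis=axis) - jnp.log(n)

# e.g. stacking the per-key grids into shape (50, 64, 64) and taking
# log_mean_prob(stack, axis=0) gives log p(x, theta) on the grid (up to a constant),
# instead of the mean of the log-probs.
```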
So I trained 10 NFs for each number of simulations (50, 100, 200, 500, 1000) and these are the results:
trained on simulations only:
trained on simulations and score/1000:
maybe the problem comes from the weight on the score?
and I use the same random key to init all these NFs, is that a problem?
I have the same behavior when I initialize the 10 NFs with different random keys
Ok so:
I'm trying another weight on the score (1500).
I'm doing the same thing with different metrics on the two moons, to see if the problem comes from the metric and/or the stochasticity of the WL model.
Since at the end we want to approximate the posterior sequentially, maybe 50 simulations are enough (and with the score loss our posterior is already much better than the one trained on simulations only). As we take this posterior as the new prior, it will be easier for the NF to learn the posterior for a given x_0, since the parameter domain is smaller. So maybe for a given set of parameters we will end up with several samples even if only 50 simulations are used. I don't know ^^
These plots are really awesome, and yeah I think the first thing to try from here is to start from no score and do a few runs with an increasing amount of score loss.
nll + score/1500
> These plots are really awesome, and yeah I think the first thing to try from here is to start from no score and do a few runs with an increasing amount of score loss.
yes I'll try this
Hummmm, it's still kinda surprising... how do you compute the NLL in practice?
negative log likelihood (weight = 0):
```python
def loss_fn(params, batch, score, weight):
    y = compressor.apply(params_compressor, batch[1])  # compress the simulated maps
    # log q(theta | y) and its gradient w.r.t. theta (the NF's score)
    log_prob, out = jax.vmap(jax.value_and_grad(
        lambda p, x: nvp.apply(params, p.reshape([-1, 2]), x.reshape([-1, 2]))))(batch[0], y)
    # NLL term + weighted L2 distance between the NF score and the simulator's joint score
    return -jnp.mean(log_prob) + jnp.mean(jnp.sum((out - score)**2, axis=1)) * weight
```
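For context, a hypothetical training step built on this loss (assuming an optax optimizer, which is an assumption on my side and not confirmed by the thread):

```python
import jax
import optax  # assumption: optax is used for optimization

optimizer = optax.adam(1e-3)

@jax.jit
def update(params, opt_state, batch, score, weight):
    # `weight` is the score-loss coefficient discussed in the thread
    loss, grads = jax.value_and_grad(loss_fn)(params, batch, score, weight)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# opt_state = optimizer.init(params); then call update(...) once per batch.
```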
negative log probability:
```python
import jax
import jax.numpy as jnp
import haiku as hk
from numpyro.handlers import condition, seed, trace

# Metric: negative log probability of the fiducial parameters under the learned posterior
precision = 1000
oc = jnp.full(precision, 0.27)
s8 = jnp.full(precision, 0.77)
theta = jnp.stack([oc, s8], axis=1)

@jax.jit
def get_truth(theta, key):
    oc = theta[0]
    s8 = theta[1]
    cond_model = condition(seed(model, key), {'omega_c': oc, 'sigma_8': s8})
    model_trace = trace(cond_model).get_trace()
    return model_trace['x']['value']

def neg_log_prob(params, params_compressor, theta, precision):
    master_seed = hk.PRNGSequence(1)
    # draw `precision` maps at the fiducial parameters and compress them
    x = jax.vmap(get_truth)(theta, jax.random.split(next(master_seed), precision))
    y = jax.vmap(lambda x: compressor.apply(params_compressor, x.reshape([-1, 64, 64])))(x)
    # average -log q(theta_fid | y) over the drawn maps
    return -jnp.mean(jax.vmap(lambda p, x: nvp.apply(params, p.reshape([-1, 2]), x.reshape([-1, 2])))(theta, y))
```
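For reference, a hypothetical call of this metric (assuming `params` holds a trained NF's parameters and the other names are as in the snippet above):

```python
# Hypothetical usage: lower values mean the NF assigns higher probability
# to the fiducial parameters on average over the drawn maps.
metric = neg_log_prob(params, params_compressor, theta, precision)
print(float(metric))
```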
for the two moons with 5 NFs for each number of simulations:
simulations only:
simulations+score/1000:
> These plots are really awesome, and yeah I think the first thing to try from here is to start from no score and do a few runs with an increasing amount of score loss.
Starting with weight = 0, I increase it by 1e-4 every 1000 updates over 10000 updates
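In code, that schedule would look something like this (a hypothetical helper, not taken from the notebook):

```python
def score_weight(step, increment=1e-4, every=1000):
    # weight starts at 0 and grows by `increment` every `every` updates
    return (step // every) * increment

# score_weight(0) == 0.0, score_weight(2500) == 2e-4, score_weight(9999) == 9e-4
```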
ah yeah no, I mean: keeping the weight constant during a given run, but doing several runs, each one with a little bit more score loss.
But this plot is still interesting... looks like it helps to start from a place where you train by NLL only.
10 NFs for each number of simulations (to get the uncertainty), 5 different losses corresponding to the 5 score weights: 0, 5e-4, 1e-4, 5e-5, 1e-5
and as a reminder, the one with simulations + score/1000 is:
Fantastic Justine! Ok ok so, the story seems to be, adding a little bit of score loss helps, especially when we have a limited budget, but too much score loss is detrimental to the NLL.
My question now is whether the same thing holds when learning just the two moons distribution
oh sorry sorry, are these plots for the lensing simulator or the two moons?
the lensing simulator
for the two moons I have this:
But I'll do the same thing with different weights to check
did the same thing for the lensing simulator again, just to check:
(I tried to do it with the two moons but I didn't take enough simulations to see the convergence at the end, so I'll redo it tomorrow :) )
same thing for the two moons (10 NFs for each number of simulations)
I had a first pass at this and made this simulator, which is not physically correct, but still has kind of the right complexity for what we are interested in.
https://colab.research.google.com/drive/1yTFMRC9v35xo-moIeoVoJQ_YMm6uXprC?usp=sharing
It generates some fake WL maps like this:
and it allows you to compute p(x, z, \theta):
From this you can compute grad_\theta \log p(x, z | \theta), i.e. the joint score. I also played a little bit with training a SIREN to learn the marginal score, but it hasn't been a huge success so far...
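A minimal sketch of how the joint score could be obtained with numpyro handlers (the site names 'omega_c', 'sigma_8', 'z', 'x' and the `model` object are assumptions based on this thread, not the exact notebook code):

```python
import jax
from numpyro.handlers import condition, seed, trace

def joint_log_prob(theta, x, z):
    omega_c, sigma_8 = theta
    cond_model = condition(model, {'omega_c': omega_c, 'sigma_8': sigma_8, 'z': z, 'x': x})
    tr = trace(seed(cond_model, jax.random.PRNGKey(0))).get_trace()
    # sum of the log-probs of all sample sites = log p(x, z, theta)
    return sum(site['fn'].log_prob(site['value']).sum()
               for site in tr.values() if site['type'] == 'sample')

# grad_theta log p(x, z, theta) = grad_theta log p(x, z | theta) + grad_theta log p(theta);
# for a flat prior on theta the two coincide.
joint_score_fn = jax.grad(joint_log_prob, argnums=0)
```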
I'm opening this issue to discuss refinements to this toy simulation. Currently the following features are missing:
And I'm gonna tag @tlmakinen on this because I stole a bunch of his code ;-P