astrodeepnet / sbi_experiments

Simulation Based Inference experiments
MIT License

Implement simple WL field level simulator (e.g. based on https://github.com/tlmakinen/kosmo-kompress ) #18

Open EiffL opened 2 years ago

EiffL commented 2 years ago

I took a first pass at this and made this simulator, which is not physically correct, but still has roughly the right complexity for what we are interested in.

https://colab.research.google.com/drive/1yTFMRC9v35xo-moIeoVoJQ_YMm6uXprC?usp=sharing

It generates some fake WL maps like this: image

and it allows you to compute p(x, z, \theta): image

From that you can compute grad_\theta \log p(x, z | theta), i.e. the joint score. I also played a little bit with training a SIREN to learn the marginal score, but it hasn't been a huge success so far...
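
For concreteness, here is a minimal toy sketch of the mechanics in JAX; the quadratic `log_joint` below is a made-up stand-in, not the notebook's model, which gets its log-density from tracing the probabilistic program:

import jax
import jax.numpy as jnp

# Toy stand-in for the simulator's log-joint, just to show the mechanics;
# in the notebook, log p(x, z | theta) comes from the probabilistic model itself.
def log_joint(theta, z, x):
    log_p_z = -0.5 * jnp.sum(z**2)                 # e.g. z ~ N(0, 1)
    log_p_x = -0.5 * jnp.sum((x - theta * z)**2)   # e.g. x ~ N(theta * z, 1)
    return log_p_z + log_p_x

# The joint score is simply grad_theta log p(x, z | theta):
joint_score = jax.grad(log_joint, argnums=0)

theta, z, x = 0.3, jnp.ones(4), jnp.ones(4)
print(joint_score(theta, z, x))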

I'm leaving this issue open to discuss refinements to this toy simulation. Currently the following features are missing:

And I'm gonna tag @tlmakinen on this because I stole a bunch of his code ;-P

tlmakinen commented 2 years ago

thanks for tagging me! I think I could easily incorporate this simulator into my IMNN compression framework to see how SBI shakes out... would be great to have the k -> ell conversion eventually...

EiffL commented 2 years ago

Yep, I have played very quickly with training a simple compressor network, but training an IMNN would be nice. The main focus of this project is how to do the inference part efficiently, starting from "a" compressed statistic.

EiffL commented 2 years ago

After additional tinkering, it looks like it kind of works :-) learning the posterior distribution with a combination of NLL and score loss (notebook here).

So for instance, here is the true posterior distribution: image

Here is what you can estimate when using only 64 simulations and an NLL loss: image

and here is what you get if you use 64 sims and a combination of NLL and score loss: image Not perfect yet (it's only 64 sims), but it already looks like a huge improvement.
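
For reference, the objective being described here (and the one that appears in the loss code later in the thread) is roughly:

L(\phi) = - E[ \log q_\phi(\theta | x) ] + w * E[ || grad_\theta \log q_\phi(\theta | x) - grad_\theta \log p(x, z | \theta) ||^2 ]

i.e. the NLL of the normalizing flow q_\phi plus a score-matching penalty against the joint score from the simulator, with w the score weight discussed below.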

So I think this example should be enough for our purposes on a timescale of about a week.

EiffL commented 2 years ago

Did you manage to run the notebook, @Justinezgh? Did you run into any problems?

Justinezgh commented 2 years ago

Yes :) But I don't really understand why you use jax.lax.stop_gradient in the eval_model function when you compute log p(theta | x, z).

EiffL commented 2 years ago

It's because you are evaluating the model at the position of the point x that has just been sampled, but as if it were a "data point" coming from outside the model, i.e. you don't want gradients to propagate through x, which should remain fixed.
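
A toy illustration of the pattern (not the notebook's eval_model, just the mechanics):

import jax
import jax.numpy as jnp

# Toy illustration of the stop_gradient pattern described above.
def eval_log_prob(theta, key):
    x = theta + jax.random.normal(key, (4,))   # sample x from the model, given theta
    x = jax.lax.stop_gradient(x)               # treat x as fixed data from here on
    return -0.5 * jnp.sum((x - theta)**2)      # log-density term evaluated at that x

# With stop_gradient, jax.grad sees x as a constant, so the gradient is (x - theta);
# without it, x - theta would just be the sampling noise and the gradient would vanish.
g = jax.grad(eval_log_prob)(jnp.zeros(4), jax.random.PRNGKey(0))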

Justinezgh commented 2 years ago

Since p(theta|x) = \int p(theta | x,z) p(z) dz, can I do this to get a 'real' approximation of the posterior:

import jax
import jax.numpy as jnp
import numpy as np
# `model` is the simulator model defined in the notebook; the handlers are assumed
# to be numpyro's (condition, seed, trace), as suggested by their usage below.
from numpyro.handlers import condition, seed, trace

def get_jointlog_posterior(params, key):
    omega_c, s8 = params

    # Generate a mock data map x at the fiducial cosmology
    cond_model = condition(model, {'omega_c': 0.27, 'sigma_8': 0.77})
    cond_model = seed(cond_model, key)
    m_data = cond_model()

    # Condition the model on (omega_c, sigma_8) and on the mock data x,
    # then sum the log-probabilities of every site in the trace
    cond_model = condition(model, {'omega_c': omega_c, 'sigma_8': s8, 'x': m_data})
    cond_model = seed(cond_model, key)
    model_trace = trace(cond_model).get_trace()
    jointlog = jnp.array([model_trace[k]['fn'].log_prob(model_trace[k]['value'])
                          for k in model_trace.keys()]).sum()
    return jointlog

# Vectorize over a batch of parameter values, for a fixed random key
get_log_prob = jax.jit(jax.vmap(get_jointlog_posterior, in_axes=[0, None]))

# Grid of (omega_c, sigma_8) values to evaluate on
p = jnp.stack(jnp.meshgrid(jnp.linspace(0.2, 0.4, 64),
                           jnp.linspace(0.7, 0.9, 64)), axis=-1)
ps = p.reshape([-1, 2])  # flattened grid of parameter pairs (not used below)

# Average the log-probabilities over 50 realizations of the latent variables
true_posterior = jnp.array([np.stack([get_log_prob(params, jax.random.PRNGKey(k)) for params in p], axis=0)
                            for k in range(50)]).mean(axis=0)

image

Just to get an idea of what I'm supposed to have

Justinezgh commented 2 years ago

Hmm, I think this ^^^ is p(x, theta), so I have p(theta|x) times a constant?

EiffL commented 2 years ago

yeah it looks ok, but remember that it's an average over the probabilities, not the log probabilities ;-)
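
Concretely, one way to do that while staying in log space is a logsumexp over the 50 realizations instead of the .mean(axis=0) over log-probs; a sketch, reusing the names from the snippet above:

import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Sketch: log of the Monte-Carlo average of the probabilities over the 50 draws,
# instead of the average of the log-probabilities.
log_probs = jnp.array([np.stack([get_log_prob(params, jax.random.PRNGKey(k))
                                 for params in p], axis=0)
                       for k in range(50)])                        # shape [50, 64, 64]
log_avg = logsumexp(log_probs, axis=0) - jnp.log(log_probs.shape[0])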

Justinezgh commented 2 years ago

So I trained 10 NFs for each number of simulations (50, 100, 200, 500, 1000), and these are the results:

image

trained on simulations only:

ezgif-5-2d7e7b129d

trained on simulations and score/1000:

ezgif-5-0b4635d982

Maybe the problem comes from the weight on the score?

Also, I use the same random key to initialize all these NFs; is that a problem?

Justinezgh commented 2 years ago

I have the same behavior when I initialize the 10 NFs with different random keys: image

Justinezgh commented 2 years ago

Ok, so:

  1. I'm trying another weight on the score (1500).

  2. I'm doing the same thing with different metrics on the two moons, to see if the problem comes from the metric and/or the stochasticity of the WL model.

  3. Since we ultimately want to approximate the posterior sequentially, maybe 50 simulations are enough (and with the score loss, that posterior is already much better than the one trained on simulations only). And as we take this posterior to become the new prior, it will be easier for the NF to learn the posterior for a given x_0, since the parameter domain is smaller. Meaning that maybe, for a given set of parameters, we will have several samples even if only 50 simulations are used (a rough sketch of that loop is below). I don't know ^^
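
A rough sketch of that sequential idea (SNPE-style rounds); `sample_prior`, `simulate`, `train_nf` and `sample_posterior` are hypothetical placeholders here, not functions from the notebook:

import jax

# Hypothetical sketch of point 3: rounds of simulation where the current posterior
# estimate at the observed data x_0 becomes the proposal for the next round.
def sequential_rounds(x_0, sample_prior, simulate, train_nf, sample_posterior,
                      key, n_rounds=3, n_sims=50):
    proposal = sample_prior                       # round 1: propose parameters from the prior
    nf_params = None
    for _ in range(n_rounds):
        key, k_theta, k_sim = jax.random.split(key, 3)
        thetas = proposal(k_theta, n_sims)        # draw parameters from the current proposal
        xs = simulate(thetas, k_sim)              # run the simulator at those parameters
        nf_params = train_nf(thetas, xs)          # fit q(theta | x) on this round's pairs
        # later rounds concentrate simulations in the region favoured by the data
        proposal = lambda k, n, nf=nf_params: sample_posterior(nf, x_0, k, n)
    return nf_params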

EiffL commented 2 years ago

These plots are really awesome, and yeah, I think the first thing to try from here is to start from no score and do a few runs with an increasing amount of score loss.

Justinezgh commented 2 years ago

NLL + score/1500: image

Justinezgh commented 2 years ago

> These plots are really awesome, and yeah, I think the first thing to try from here is to start from no score and do a few runs with an increasing amount of score loss.

yes I'll try this

EiffL commented 2 years ago

Hummmm, it's still kind of surprising... how do you compute the NLL in practice?

Justinezgh commented 2 years ago

Negative log-likelihood (this is the full loss; with weight = 0 it reduces to the NLL):

def loss_fn(params, batch, score, weight):
  # Compress the simulated maps into summary statistics y
  y = compressor.apply(params_compressor, batch[1])
  # Evaluate log q(theta | y) and its gradient with respect to theta,
  # i.e. the score of the normalizing flow at each (theta, y) pair
  log_prob, out = jax.vmap(jax.value_and_grad(
      lambda p, x: nvp.apply(params, p.reshape([-1, 2]), x.reshape([-1, 2]))))(batch[0], y)
  # NLL term plus weighted score-matching term against the simulator's joint score
  return -jnp.mean(log_prob) + jnp.mean(jnp.sum((out - score)**2, axis=1)) * weight
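
For context, a sketch of how this loss would typically be plugged into a training step, assuming an optax optimizer (the notebook's actual training loop may differ):

import jax
import optax

optimizer = optax.adam(1e-3)

@jax.jit
def update(params, opt_state, batch, score, weight):
    # One gradient step on the combined NLL + weighted score-matching loss above
    loss, grads = jax.value_and_grad(loss_fn)(params, batch, score, weight)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
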
Justinezgh commented 2 years ago

Negative log probability metric:

# Metric: negative log probability of the fiducial parameters under the learned posterior
precision = 1000
oc = jnp.full(precision, 0.27)
s8 = jnp.full(precision, 0.77)
theta = jnp.stack([oc, s8], axis=1)  # `precision` copies of the fiducial (omega_c, sigma_8)

@jax.jit
def get_truth(theta, key):

  oc = theta[0]
  s8 = theta[1]

  # Simulate a map x at the given parameters
  cond_model = condition(seed(model, key), {'omega_c': oc, 'sigma_8': s8})
  model_trace = trace(cond_model).get_trace()
  return model_trace['x']['value']

# Negative log prob
def neg_log_prob(params, params_compressor, theta, precision):

  master_seed = hk.PRNGSequence(1)
  # Simulate `precision` independent maps at the fiducial parameters and compress them
  x = jax.vmap(get_truth)(theta, jax.random.split(next(master_seed), precision))
  y = jax.vmap(lambda x: compressor.apply(params_compressor, x.reshape([-1, 64, 64])))(x)

  # Average of -log q(theta_fid | y) over the simulated data realizations
  return -jnp.mean(jax.vmap(lambda p, x: nvp.apply(params, p.reshape([-1, 2]), x.reshape([-1, 2])))(theta, y))
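
In words, this metric simulates `precision` fresh maps at the fiducial parameters, compresses them, and reports the average of -log q(theta_fid | y) under the trained flow. A hypothetical usage sketch (`trained_params` is a placeholder dict of flow parameters keyed by number of simulations, not something from the notebook):

# Compare trained flows by their negative log probability at the fiducial parameters
for n_sims, params_nf in trained_params.items():
    print(n_sims, neg_log_prob(params_nf, params_compressor, theta, precision))
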
Justinezgh commented 2 years ago

For the two moons, with 5 NFs for each number of simulations:

image image image image

simulations only: nllfinal

simulations + score/1000: nllscorefinal

Justinezgh commented 2 years ago

> These plots are really awesome, and yeah, I think the first thing to try from here is to start from no score and do a few runs with an increasing amount of score loss.

Starting with weight = 0, I increase it by 1e-4 every 1000 updates, over 10000 updates in total: image
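
For reference, that schedule as a function of the update step (a sketch):

# Sketch of the annealing schedule described above: the weight starts at 0 and
# increases by 1e-4 every 1000 updates, over 10000 updates in total.
def score_weight(step):
    return 1e-4 * (step // 1000)

# score_weight(0) -> 0.0, score_weight(9999) -> 9e-4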

EiffL commented 2 years ago

Ah, no, I meant keeping the weight constant during a given run, but doing several runs, each one with a little bit more score loss.

But this plot is still interesting though... it looks like it helps to start from a model trained with the NLL alone.

Justinezgh commented 2 years ago

10 NFs for each number of simulations (to get the uncertainty), and 5 different losses corresponding to the 5 score weights 0, 5e-4, 1e-4, 5e-5, 1e-5: image

image

image

image

image

And as a reminder, the one with simulations + score/1000 is:

EiffL commented 2 years ago

Fantastic, Justine! OK, so the story seems to be: adding a little bit of score loss helps, especially when we have a limited simulation budget, but too much score loss is detrimental to the NLL.

My question now is whether the same thing is true when learning just the two moons distribution.

EiffL commented 2 years ago

oh sorry sorry, are these plots for the lensing simulator or the two moons?

Justinezgh commented 2 years ago

the lensing simulator

Justinezgh commented 2 years ago

For the two moons I have this:

But I'll do the same thing with different weights to check.

Justinezgh commented 2 years ago

I did the same thing for the lensing simulator again, just to check: image image image image image image

(I tried to do it with the two moons, but I didn't use enough simulations to see the convergence at the end, so tomorrow :) )

Justinezgh commented 2 years ago

Same thing for the two moons (10 NFs for each number of simulations):

image image image image image image