astrodeepnet / sbi_experiments

Simulation Based Inference experiments
MIT License

Decide on metrics to compare performance of NDEs #17

Open EiffL opened 2 years ago

EiffL commented 2 years ago

I'm opening this issue to discuss which metrics we want to use, and some first comparison plots.

@Justinezgh have you looked at a few metrics we could use? Even just on our two moons distribution for now, we can compare different architectures for the flow in terms of these metrics, and check that they indeed translate into good measures of whether or not the distribution looks ok.

Justinezgh commented 2 years ago

In the paper Benchmarking Simulation-Based Inference they propose different metrics depending on the 'true information' available and the algorithm used for the approximation.

If we only have the 'true' observations and we assume that we have access to the ground truth parameters we can use:

- the negative log probability of the true parameters (NLTP)

If we can sample from the ground truth posterior:

- the classifier two-sample test (C2ST)
- the maximum mean discrepancy (MMD)

If we have access to the gradients of the ground truth posterior:

- the kernelized Stein discrepancy (KSD)

If we have access to the ground truth posterior:

- f-divergences such as the KL divergence

[equation image] with q̃ the unnormalized posterior approximation and q the normalized one.
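(A hedged reconstruction of that equation from the caption alone: the normalized approximation is the unnormalized one divided by its normalizing constant.)

$$q(\theta) = \frac{\tilde{q}(\theta)}{\int \tilde{q}(\theta')\,\mathrm{d}\theta'}$$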

Justinezgh commented 2 years ago

If we only want to test different architectures, maybe an f-divergence like the KL divergence is the best option, since we have access to both the true distribution and the approximate one. And maybe the KSD too, since we consider the gradients (but I'm not used to it at all :D)
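(For reference, a hedged sketch of the Monte Carlo estimator this suggests, matching the `kl_divergence_fn` implemented further down the thread: sample from the approximation $q$ and average the log-ratio against the true density $p$.)

$$D_{\mathrm{KL}}(q \,\|\, p) = \mathbb{E}_{z \sim q}\big[\log q(z) - \log p(z)\big] \approx \frac{1}{N}\sum_{i=1}^{N}\big[\log q(z_i) - \log p(z_i)\big], \qquad z_i \sim q$$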

EiffL commented 2 years ago

Yeah, so I think in our toy experiments with analytic distributions like two moons, we could use all of them, and it would be interesting to see how much an improvement in these metrics translates into a qualitative improvement of the posterior.

In practical cosmology applications, we'll see, it depends on the particular scenario we are considering. We might still be able to obtain samples and gradients of the ground truth, but they would be expensive to obtain (we would need to run an HMC).

For practical reasons, let's start with one, whichever is simplest to implement. I guess all the code can be found in Jan-Matthis' library, but it's all in pytorch ^^'

Justinezgh commented 2 years ago

Just to be sure: what we want to do first is to use all these metrics on different architectures to evaluate whether the NFs learn the two moons distribution correctly?

EiffL commented 2 years ago

yep, at stage 0, but we don't have to do all the metrics at once, we can start with one and try to make a plot.

The plot we want is number of sims on the x axis, quality metric on the y axis. And on this plot, you can have several lines, one per model/architecture/training loss used.

We'll ultimately want to have several of these plots, one per distribution we want to learn. We can start with 2 moons, but then we'll want to use some cosmology example.
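(A minimal sketch of how such a benchmark plot might be assembled, assuming hypothetical helpers `train_model(arch, n_sims)` and `compute_metric(params)` that wrap the training and evaluation code from the notebooks:)

```python
import matplotlib.pyplot as plt

# train_model and compute_metric are hypothetical wrappers, not repo code.
n_sims_grid = [500, 1000, 5000, 10000, 50000]
architectures = ["affine-coupling", "rq-spline"]

for arch in architectures:
    metric_values = []
    for n_sims in n_sims_grid:
        params = train_model(arch, n_sims)            # train the NF on n_sims simulations
        metric_values.append(compute_metric(params))  # e.g. the KL estimate
    plt.plot(n_sims_grid, metric_values, label=arch)

plt.xscale("log")
plt.xlabel("number of simulations")
plt.ylabel("quality metric (lower is better)")
plt.legend()
plt.show()
```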

Justinezgh commented 2 years ago

I tried to implement the dkl but I get negative values and a weird behavior: [plot: KL estimate vs number of samples]

Here is my dkl; if you see a mistake I'll be happy to know what it is ^^'

```python
# Haiku-transformed functions for the flow: log_prob and sampler.
model_NF = hk.without_apply_rng(hk.transform(lambda x: Flow()().log_prob(x)))
model_sample = hk.without_apply_rng(
    hk.transform(lambda nb_sample: Flow()().sample(nb_sample, seed=next(rng_seq))))

def kl_divergence_fn(params, distribution_b, precision):
  # Monte Carlo estimate of KL(q || p): draw `precision` samples z ~ q from
  # the flow, then average log q(z) - log p(z) over them.
  z = model_sample.apply(params, precision)
  return jnp.mean(model_NF.apply(params, z) - distribution_b.log_prob(z))
```
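(A hedged usage sketch, assuming `params` are the trained flow parameters and `two_moons` is the analytic target distribution from the notebook:)

```python
# Evaluate the KL estimate with an increasing number of Monte Carlo samples
# to see where it stabilizes; `params` and `two_moons` are assumptions here.
for n in [100, 1000, 10000, 100000]:
    print(n, kl_divergence_fn(params, two_moons, n))
```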
EiffL commented 2 years ago

ouuuuuh that's really cool :-D

On the x axis, you have the number of samples you use to evaluate the metric, right? So what we are seeing is that it needs 100000 samples to converge to a given value?

Justinezgh commented 2 years ago

hmm maybe I'm confused but I think that's not cool ^^ Yes, the x axis is the number of simulations and the y axis is the KL divergence (which is supposed to be close to zero if the approximation is good, right?). Plus, this is the contour plot for the nll loss:

[contour plot for the nll loss]

and the one for the score/2000 + nll loss:

[contour plot for the score/2000 + nll loss]

so we see that the nll loss gets better results faster, but I think this is not what we see in the dkl plot

Justinezgh commented 2 years ago

This is the notebook: link

I also have the plot for the mmd metric: [plot: MMD metric]

Justinezgh commented 2 years ago

NF log_prob >>> two moons log_prob, and I think the problem comes from the `tfd.Independent` in the output of the NF

https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/Independent


yup :/ [image]

EiffL commented 2 years ago

hummmm yeah..... but that's what tfd.Independent is supposed to do, no? It sums over the dimensions that are not batched
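(A small runnable check of that behavior, not from the thread: in the JAX substrate of TFP, `tfd.Independent` reinterprets trailing batch dimensions as event dimensions, so `log_prob` returns one scalar per batch point.)

```python
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

# A 2d standard normal: Independent sums the per-coordinate log-probs.
base = tfd.Independent(tfd.Normal(loc=jnp.zeros(2), scale=jnp.ones(2)),
                       reinterpreted_batch_ndims=1)
x = jnp.zeros((5, 2))          # a batch of 5 points in 2d
print(base.log_prob(x).shape)  # (5,): one log-prob per point, as expected
```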

Justinezgh commented 2 years ago

hmm yes right, I don't know where the problem is

EiffL commented 2 years ago

but so, what kind of values do you have for the flow log_prob? I'm surprised if they are very big, because the NLL loss, if I remember correctly, gives you like -1.5 at most; that would be the average log prob for samples drawn from the two moons
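(A hedged way to check that figure, assuming `two_moons` is the analytic target distribution used in the notebooks: by Gibbs' inequality, the average log prob of the target over its own samples bounds what the flow can reach.)

```python
import jax
import jax.numpy as jnp

# Estimate E[log p(z)] under the two moons distribution itself; the trained
# flow's average log_prob should be at most about this value.
key = jax.random.PRNGKey(0)
s = two_moons.sample(10_000, seed=key)
print(jnp.mean(two_moons.log_prob(s)))
```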

Justinezgh commented 2 years ago

yes, it was because I did

```python
model_NF.apply(params, batch)
```

instead of

```python
jax.vmap(lambda x: model_NF.apply(params, x.reshape([1, 2])).squeeze())(batch)
```

still a bit lost with Jax :)
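(For context, a tiny example, not from the thread, of what `jax.vmap` does here: it maps a single-example function over the leading batch axis, sidestepping whatever the flow does wrong with batched inputs.)

```python
import jax
import jax.numpy as jnp

# A function written for one example of shape (2,)...
f = lambda x: jnp.sum(x ** 2)
batch = jnp.ones((5, 2))
# ...applied independently to each of the 5 rows by vmap:
print(jax.vmap(f)(batch))  # shape (5,)
```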

EiffL commented 2 years ago

[image]

Lol, just got this ^^^^^

EiffL commented 2 years ago

yeah... so... something's wrong in the way our code for the flow handles the batch dimension :-| ya..., not great. But at least if we vmap it works ^^

EiffL commented 2 years ago

it might be a problem in how the coupling layer is implemented, idk
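(To illustrate the suspicion, a minimal sketch, not the repo's actual code, of a batch-safe affine coupling forward pass: a classic bug is reducing the log-determinant over all axes instead of only the event axis, which breaks batched log_prob.)

```python
import jax.numpy as jnp

def affine_coupling_forward(x, shift_and_log_scale_fn, num_masked=1):
    # x has shape (..., d): arbitrary batch dims, event dim last.
    x_cond, x_trans = x[..., :num_masked], x[..., num_masked:]
    shift, log_scale = shift_and_log_scale_fn(x_cond)
    y_trans = x_trans * jnp.exp(log_scale) + shift
    y = jnp.concatenate([x_cond, y_trans], axis=-1)
    # Sum the log-det over the event axis ONLY, so batch dims survive;
    # jnp.sum(log_scale) with no axis would collapse them incorrectly.
    log_det = jnp.sum(log_scale, axis=-1)
    return y, log_det
```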

Justinezgh commented 2 years ago

```python
batch_size = [50, 70, 90, 110]
```

[four plots: results for each batch size]

Just a test, I'll try with smaller batch sizes (but sometimes when the batch size is small (like 40) I get nan in the loss)

notebook

EiffL commented 2 years ago

Did you get better results without the outlier at N=70?

Justinezgh commented 2 years ago

Hmm not really, this is the best I have for now:

Loss = nll: [plot]

Loss = nll + score/1000: [plot]

Metrics: [plots for c2st, dkl, nlp, mmd]
I changed the variance of the normal distribution for the latent variable to avoid nans in the loss (it was set to 0.1):

```python
# Base distribution for the flow: a truncated normal on [0.01, 0.99],
# with the scale lowered from 0.1 to 0.05.
nvp = tfd.TransformedDistribution(
    tfd.Independent(tfd.TruncatedNormal(0.5 * jnp.ones(d),
                                        0.05 * jnp.ones(d),
                                        0.01, 0.99),
                    reinterpreted_batch_ndims=1),
    bijector=chain)
```
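(A quick hedged sanity check on the new base distribution, assuming `nvp`, `d`, and `chain` as above, and reusing the vmap workaround from earlier in the thread:)

```python
import jax
import jax.numpy as jnp

# Sample from the flow and confirm the log-probs stay finite after the
# change of scale; NaNs here would point back at the base distribution.
key = jax.random.PRNGKey(0)
samples = nvp.sample(1000, seed=key)
lp = jax.vmap(lambda x: nvp.log_prob(x))(samples)
print(jnp.isnan(lp).any(), jnp.isinf(lp).any())  # both should be False
```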

and I added randomness during training :

```python
# Draw one fixed batch of simulations (and their scores), then subsample
# a random half of it (with replacement) at each training step.
batch, score_batch = get_batch(jax.random.PRNGKey(0))
for step in tqdm(range(5000)):
  inds = np.random.randint(0, batch_size, int(batch_size / 2))
  l, params, opt_state = update(params, opt_state, batch[inds], score_batch[inds])
```
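(A hedged variant, not from the thread: drawing the minibatch indices with JAX's functional PRNG instead of stateful `np.random` keeps the whole run reproducible from one seed.)

```python
import jax

rng = jax.random.PRNGKey(1)
for step in tqdm(range(5000)):
    rng, key = jax.random.split(rng)  # fresh subkey each step
    inds = jax.random.randint(key, (batch_size // 2,), 0, batch_size)
    l, params, opt_state = update(params, opt_state, batch[inds], score_batch[inds])
```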
Justinezgh commented 2 years ago

loss = nll: [image]

loss = nll + score/1000: [two images]