sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

cond_coeff_mat - normalisation issue #403

Closed jnsbck closed 3 years ago

jnsbck commented 3 years ago

Under some circumstances, the conditional correlation matrix does not seem to be normalised properly: off-diagonal entries can fall outside [-1, 1]. (...or there is something else that I am missing entirely.)
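For reference, here is a minimal sketch of how a correlation coefficient can be computed from a 2D density evaluated on a regular grid between given limits. It assumes, loosely mirroring the `limits` argument of `conditional_corrcoeff`, that the density is evaluated on such a grid. `grid_corrcoeff` and its `resolution` parameter are hypothetical, and this is not sbi's implementation. Because the result is a weighted Pearson correlation with non-negative weights, the Cauchy-Schwarz inequality keeps it in [-1, 1].

```python
import numpy as np


def grid_corrcoeff(density_fn, limits, resolution=100):
    """Correlation coefficient of a 2D density evaluated on a regular grid.

    Hypothetical helper for illustration only, not sbi's implementation.
    density_fn maps an (N, 2) array of points to N non-negative,
    possibly unnormalised density values.
    limits is [[low_0, high_0], [low_1, high_1]].
    """
    x = np.linspace(limits[0][0], limits[0][1], resolution)
    y = np.linspace(limits[1][0], limits[1][1], resolution)
    xx, yy = np.meshgrid(x, y, indexing="ij")
    points = np.stack([xx.ravel(), yy.ravel()], axis=1)
    p = density_fn(points).reshape(resolution, resolution)

    # Normalise so the grid weights sum to one; on a uniform grid the
    # cell area is a constant factor and cancels out.
    p = p / p.sum()

    # Weighted means, variances and covariance under the grid weights.
    mean_x = np.sum(p * xx)
    mean_y = np.sum(p * yy)
    var_x = np.sum(p * (xx - mean_x) ** 2)
    var_y = np.sum(p * (yy - mean_y) ** 2)
    cov_xy = np.sum(p * (xx - mean_x) * (yy - mean_y))
    return cov_xy / np.sqrt(var_x * var_y)


def correlated_gaussian(points, rho=0.8):
    """Unnormalised standard bivariate Gaussian with correlation rho."""
    x, y = points[:, 0], points[:, 1]
    q = (x**2 - 2 * rho * x * y + y**2) / (1 - rho**2)
    return np.exp(-0.5 * q)


r = grid_corrcoeff(correlated_gaussian, [[-6, 6], [-6, 6]], resolution=200)
print(r)  # approximately 0.8
```

With a grid that covers the bulk of the mass, the estimate recovers the true correlation of the Gaussian; with any non-negative weights it cannot leave [-1, 1].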

Minimal code example to reproduce the issue:

import numpy as np
import torch
import matplotlib.pyplot as plt

# sbi
import sbi.utils as utils
from sbi.inference.base import infer
from sbi.inference import SNPE, prepare_for_sbi
from sbi.utils.get_nn_models import posterior_nn
from sbi.utils import pairplot, conditional_pairplot, conditional_corrcoeff, eval_conditional_density

# setting random seed
np.random.seed(1)

def simulator_placeholder(params):
    """Noisy Identity.
    Takes a vector of 2 parameters and adds noise."""

    return np.random.multivariate_normal(np.array(params),np.eye(2))

# mock observation
x_o = np.array([ 10., 50.])

# pre-simulating data
prior_min = [0.1, 0.001]
prior_max = [100., 100.]
prior = utils.torchutils.BoxUniform(low=torch.as_tensor(prior_min), 
                                    high=torch.as_tensor(prior_max))

prior_sample = prior.sample((2000,))

thetas = []
outputs = []
for params in prior_sample:
    result = simulator_placeholder(params)
    outputs.append(result)
    thetas.append(params)

thetas = torch.stack(thetas, dim=0)
outputs = np.array(outputs)

# setting up sbi
simulator, prior = prepare_for_sbi(simulator_placeholder, prior)
density_estimator_build_fun = posterior_nn(model='mdn')

inference = SNPE(simulator, prior, density_estimator=density_estimator_build_fun, 
                 show_progress_bars=True, num_workers=2)

inference.provide_presimulated(torch.as_tensor(thetas, dtype=torch.float32),
                               torch.as_tensor(outputs, dtype=torch.float32), from_round=0)

# running the inference
proposal = None
posterior = inference(num_simulations=0, proposal=proposal)

# sampling a condition
posterior.set_default_x(x_o)
condition = posterior.sample((1,))

# compute conditional correlations
cond_coeff_mat = conditional_corrcoeff(
    density=posterior,
    condition=condition,
    limits=torch.tensor(list(zip(prior_min,prior_max))))
fig, ax = plt.subplots(1,1, figsize=(4,4))
im = plt.imshow(cond_coeff_mat, cmap='PiYG') # without clim=[-1,1]
_ = fig.colorbar(im)
plt.show()

print(cond_coeff_mat)

The output this produces is similar to the following:

[Image: Figure_1, imshow plot of cond_coeff_mat]

tensor([[ 1.0000, -2.8905],
        [-2.8905,  1.0000]])
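The off-diagonal entry of -2.8905 cannot be a correlation coefficient, since those are bounded in [-1, 1]. A quick sanity check along these lines flags such a matrix; `check_corrcoeff_matrix` is a hypothetical helper, not part of sbi.

```python
import numpy as np


def check_corrcoeff_matrix(mat, atol=1e-6):
    """Return True if mat looks like a valid correlation matrix.

    Checks symmetry, a unit diagonal and all entries in [-1, 1].
    Hypothetical helper for illustration, not part of sbi.
    """
    mat = np.asarray(mat, dtype=float)
    symmetric = np.allclose(mat, mat.T, atol=atol)
    unit_diag = np.allclose(np.diag(mat), 1.0, atol=atol)
    bounded = np.all(np.abs(mat) <= 1.0 + atol)
    return bool(symmetric and unit_diag and bounded)


# The matrix printed above.
reported = [[1.0000, -2.8905], [-2.8905, 1.0000]]
print(check_corrcoeff_matrix(reported))  # False, because |-2.8905| > 1
```

Since `np.asarray` converts CPU tensors, the same check should also work directly on `cond_coeff_mat`.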

@ybernaerts FYI

jnsbck commented 3 years ago

Thanks a ton for the quick fix! Seems to do the trick :)

michaeldeistler commented 3 years ago

Hi Jonas!

Thanks for reporting this issue. As you have already seen, I have opened PR #404. However, I had previously committed a "solution" to this PR that was not correct. Judging by when you closed this issue, you are probably using that incorrect "solution". I have fixed it now. Please have another look at #404 and update your code accordingly.

Again, thanks for reporting! Michael

jnsbck commented 3 years ago

Hey Michael,

I noticed something was off a few days ago too, but I wasn't able to reproduce it in the stripped-down example I provided, so I assumed something was wrong with my implementation. Glad to know it's solved now. Thanks for the heads-up and the fix!

Best,

Jonas

michaeldeistler commented 3 years ago

Cool! Feel free to re-open the issue if the current fix still behaves weirdly for you :)

Michael

jnsbck commented 3 years ago

I ran my posteriors through the function, and the output seems a lot more sensible now. :)

I just noticed that this also changes the output of the tutorial on your website, in case you want to update it accordingly. :)

New: [image] vs. Old: [image]

michaeldeistler commented 3 years ago

Hi Jonas,

Sorry for getting back to this so late. When I run the tutorial notebook, my conditional correlation matrix does not change: it still looks like the "old" matrix (the second one you showed). Are you sure your code is aligned with the current main branch?

Michael

jnsbck commented 3 years ago

Hi Michael,

Indeed, my code was out of sync with main. I just checked, and the output now matches the old one. :)

Jonas