stefanradev93 / BayesFlow

A Python library for amortized Bayesian workflows using generative neural networks.
https://bayesflow.org/
MIT License
286 stars, 45 forks

Diagnostic plots do not do so well with simple (one-parameter) models #135

Kucharssim closed this issue 5 months ago

Kucharssim commented 5 months ago

Diagnostic plots fail when the model has only one parameter; see the simple example below. I can make a PR to fix this myself, but it seems I don't have permissions to assign myself to this issue, as instructed in CONTRIBUTING.

import bayesflow as bf
import numpy as np
import matplotlib.pyplot as plt

RNG = np.random.default_rng(314159)

model = bf.simulation.GenerativeModel(
    prior = bf.simulation.Prior(
        prior_fun = lambda: np.r_[RNG.beta(a=1, b=1)]
    ),
    simulator = bf.simulation.Simulator(
        simulator_fun = lambda theta: np.r_[RNG.binomial(n=10, p=theta)]
    )
)

amortizer = bf.amortizers.AmortizedPosterior(
    inference_net = bf.inference_networks.InvertibleNetwork(num_params=1)
)

trainer = bf.trainers.Trainer(
    amortizer = amortizer,
    generative_model = model,
    # Map prior draws and simulated data to the keys the amortizer expects, cast to float32
    configurator = lambda forward_dict: {
        "parameters": forward_dict["prior_draws"].astype(np.float32),
        "direct_conditions": forward_dict["sim_data"].astype(np.float32),
    }
)

# Simulate 100 data sets, configure them, then draw 500 posterior samples per data set
sims = trainer.configurator(model(100))
posterior_samples = amortizer.sample(sims, n_samples=500)

print(sims["parameters"].shape)
print(posterior_samples.shape)

# Try each diagnostic plot and report the error message instead of silently swallowing it
try:
    f = bf.diagnostics.plot_sbc_histograms(posterior_samples, sims["parameters"], num_bins=5)
except Exception as err:
    print("Could not plot sbc histogram:", err)

try:
    f = bf.diagnostics.plot_sbc_ecdf(posterior_samples, sims["parameters"], difference=True, stacked=False)
except Exception as err:
    print("Could not plot sbc ecdf plot:", err)

try:
    f = bf.diagnostics.plot_recovery(posterior_samples, sims["parameters"])
except Exception as err:
    print("Could not plot recovery plot:", err)

try:
    f = bf.diagnostics.plot_z_score_contraction(posterior_samples, sims["parameters"])
except Exception as err:
    print("Could not plot posterior contraction plot:", err)

plt.show()
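
The thread does not pinpoint where the plotting functions break. A common source of exactly this kind of one-parameter failure, assumed here rather than confirmed for BayesFlow's diagnostics code, is that `matplotlib.pyplot.subplots` returns a bare `Axes` object instead of an array when the grid is 1 x 1, so array-style code such as `axarr.flat` stops working. The sketch below demonstrates the behavior and two standard workarounds (`squeeze=False` and `np.atleast_1d`).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

num_params = 1  # the one-parameter case from the report

# Default (squeeze=True): a 1x1 grid yields a bare Axes object, not an array,
# so code that does `for ax in axarr.flat:` raises AttributeError.
_, single = plt.subplots(1, num_params)
print(type(single).__name__, "has .flat:", hasattr(single, "flat"))

# squeeze=False always returns a 2D NumPy array of Axes, whatever the grid size.
_, grid = plt.subplots(1, num_params, squeeze=False)
print("squeeze=False shape:", grid.shape)

# Alternative: promote whatever plt.subplots returned to at least a 1D array.
_, axarr = plt.subplots(1, num_params)
wrapped = np.atleast_1d(axarr)
for ax in wrapped.flat:
    ax.set_xlabel("parameter 1")
print("atleast_1d shape:", wrapped.shape)

plt.close("all")
```

`squeeze=False` is the more predictable choice when the grid shape is computed from `num_params`, since downstream indexing never has to special-case a single panel.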
stefanradev93 commented 5 months ago

This may be a nice-to-have feature. Even though one-parameter models are rare (and can usually be tackled with much easier methods than SBI), they do pop up now and then...
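
For the example in the report, the "much easier method" is exact: with a Beta(1, 1) prior and a Binomial(10, theta) likelihood, the posterior after observing k successes is Beta(1 + k, 1 + 10 - k) by conjugacy. A minimal NumPy sketch that draws exact posterior samples follows; the helper name `exact_posterior_samples` is hypothetical. Such draws could also serve as a reference when checking the amortizer's output.

```python
import numpy as np

RNG = np.random.default_rng(314159)

# Same generative model as in the report: theta ~ Beta(1, 1), k ~ Binomial(10, theta).
N_TRIALS = 10
PRIOR_A, PRIOR_B = 1.0, 1.0


def exact_posterior_samples(k, n_samples, rng=RNG):
    """Draw from the conjugate posterior Beta(a + k, b + n - k) (hypothetical helper)."""
    return rng.beta(PRIOR_A + k, PRIOR_B + N_TRIALS - k, size=n_samples)


# Simulate one observation, then sample its exact posterior.
theta_true = RNG.beta(PRIOR_A, PRIOR_B)
k_obs = RNG.binomial(n=N_TRIALS, p=theta_true)
samples = exact_posterior_samples(k_obs, n_samples=500)

# Posterior mean of Beta(a + k, b + n - k) is (a + k) / (a + b + n).
analytic_mean = (PRIOR_A + k_obs) / (PRIOR_A + PRIOR_B + N_TRIALS)
print("k =", k_obs, "sample mean =", samples.mean(), "analytic mean =", analytic_mean)
```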

Kucharssim commented 5 months ago

Of course! Note that the rest of the code seems to work with such models; it's just the plotting functions that do not. Even though it's a little bit silly to use SBI for such an example, it would be nice to run it without errors 😄
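
Since only the plotting fails, it may help to note that the statistic behind the SBC histogram and ECDF plots, the rank of each true parameter value among its posterior draws, does not depend on the number of parameters at all. The NumPy sketch below computes it; `sbc_ranks` is a hypothetical helper, not BayesFlow's implementation, and the layouts `(n_sims, n_samples, n_params)` for posterior draws and `(n_sims, n_params)` for true values are assumptions. To stay self-contained, it uses exact Beta-Binomial posterior draws in place of amortizer output, so the ranks should come out roughly uniform.

```python
import numpy as np


def sbc_ranks(post_samples, prior_samples):
    """Rank of each true parameter among its posterior draws (hypothetical helper).

    post_samples:  (n_sims, n_samples, n_params)
    prior_samples: (n_sims, n_params)
    returns:       (n_sims, n_params) integer ranks in [0, n_samples]
    """
    post_samples = np.asarray(post_samples)
    prior_samples = np.asarray(prior_samples)
    # Broadcast the true values over the sample axis; works the same for n_params == 1.
    return (post_samples < prior_samples[:, np.newaxis, :]).sum(axis=1)


rng = np.random.default_rng(314159)
n_sims, n_samples = 100, 500

# One-parameter Beta-Binomial model: theta ~ Beta(1, 1), k ~ Binomial(10, theta).
prior = rng.beta(1, 1, size=(n_sims, 1))
k = rng.binomial(10, prior)

# Exact conjugate posterior draws Beta(1 + k, 11 - k), shape (n_sims, n_samples, 1).
post = rng.beta(1 + k[:, np.newaxis, :], 11 - k[:, np.newaxis, :], size=(n_sims, n_samples, 1))

ranks = sbc_ranks(post, prior)
hist, _ = np.histogram(ranks[:, 0], bins=5, range=(0, n_samples))
print("rank histogram (5 bins):", hist)
```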