sbi-benchmark / sbibm

Simulation-based inference benchmark
https://sbi-benchmark.github.io
MIT License

'MCMCPosterior' object has no attribute 'copy_hyperparameters_from' #53

Closed: gisilvs closed this issue 1 year ago.

gisilvs commented 1 year ago

When I try to run SNLE as in the documentation example, the code fails when snle.py calls posterior.copy_hyperparameters_from(posteriors[-1]), because MCMCPosterior apparently has no method called copy_hyperparameters_from. What should the correct behaviour be? Would it make sense to just remove the line, or does this need a proper fix?
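The "just remove the line" option can be pictured as a guard that calls copy_hyperparameters_from only when the posterior object still provides it. This is an illustrative sketch, not the sbibm fix: copy_hyperparameters_if_supported, LegacyPosterior and NewStylePosterior are hypothetical names that do not exist in sbi or sbibm. Skipping the copy is only safe assuming the MCMC settings are supplied again whenever a new posterior is built.

```python
from typing import Any


def copy_hyperparameters_if_supported(posterior: Any, previous: Any) -> bool:
    """Call posterior.copy_hyperparameters_from(previous) only if that method exists.

    Returns True if the hyperparameters were copied and False if the method is
    missing (as reported in this issue for MCMCPosterior).
    """
    copy = getattr(posterior, "copy_hyperparameters_from", None)
    if callable(copy):
        copy(previous)
        return True
    return False


# Hypothetical stand-in for an old-style posterior that still has the method.
class LegacyPosterior:
    def __init__(self, mcmc_method: str = "slice_np"):
        self.mcmc_method = mcmc_method

    def copy_hyperparameters_from(self, other: "LegacyPosterior") -> "LegacyPosterior":
        # Copy the sampler setting from the previous round's posterior.
        self.mcmc_method = other.mcmc_method
        return self


# Hypothetical stand-in for a new-style posterior without the method.
class NewStylePosterior:
    def __init__(self, mcmc_method: str = "slice_np"):
        self.mcmc_method = mcmc_method
```

With a legacy object the settings are copied. With a new-style object the call is silently skipped, which is exactly why this only works if the settings are passed in some other way.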

Here is my code to reproduce the error:

import sbibm

task = sbibm.get_task("two_moons")  # See sbibm.get_available_tasks() for all tasks
prior = task.get_prior()
simulator = task.get_simulator()
observation = task.get_observation(num_observation=1)  # 10 per task

# These objects can then be used for custom inference algorithms, e.g.
# we might want to generate simulations by sampling from prior:
thetas = prior(num_samples=10_000)
xs = simulator(thetas)

# Alternatively, we can import existing algorithms, e.g:
from sbibm.algorithms import snle  # See help(snle) for keywords
posterior_samples, _, _ = snle(task=task, num_samples=10_000, num_observation=1, num_simulations=100_000, neural_net='maf')

# Once we got samples from an approximate posterior, compare them to the reference:
from sbibm.metrics import c2st
reference_samples = task.get_reference_posterior_samples(num_observation=1)
c2st_accuracy = c2st(reference_samples, posterior_samples)

# Visualise both posteriors:
from sbibm.visualisation import fig_posterior
fig = fig_posterior(task_name="two_moons", observation=1, samples=[posterior_samples])
# Note: Use fig.show() or fig.save() to show or save the figure

# Get results from other algorithms for comparison:
from sbibm.visualisation import fig_metric
results_df = sbibm.get_results(dataset="main_paper.csv")
fig = fig_metric(results_df.query("task == 'two_moons'"), metric="C2ST")

janfb commented 1 year ago

Indeed, this method has been deprecated in the sbi package since v0.18.0. It seems we missed that line when we updated the sbibm run function to the new sbi API.
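The method is deprecated as of sbi v0.18.0 and is missing from MCMCPosterior in the reported setup. Code that must run against both older and newer sbi releases could therefore branch on the installed version. Below is a minimal sketch that uses only the standard library. has_legacy_posterior_api and the other helpers are hypothetical names, and the version parsing handles plain dotted versions (ignoring suffixes such as rc1), not the full PEP 440 grammar.

```python
import re
from importlib import metadata
from typing import Optional, Tuple

# First sbi release in which copy_hyperparameters_from should no longer be relied on (per this thread).
API_CHANGE_VERSION: Tuple[int, int, int] = (0, 18, 0)


def parse_version(version: str) -> Tuple[int, int, int]:
    """Turn '0.17.2', '0.18' or '0.18.0rc1' into a (major, minor, patch) tuple.

    Only the leading dotted-number part is used; suffixes like 'rc1' are ignored.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    numbers = [int(part) for part in match.group(0).split(".")]
    # Pad short versions ('0.18' -> 0.18.0) and drop anything past the patch level.
    numbers = (numbers + [0, 0, 0])[:3]
    return (numbers[0], numbers[1], numbers[2])


def has_legacy_posterior_api(version: str) -> bool:
    """True if this sbi version predates the posterior API change."""
    return parse_version(version) < API_CHANGE_VERSION


def installed_sbi_version() -> Optional[str]:
    """Return the installed sbi version string, or None if sbi is not installed."""
    try:
        return metadata.version("sbi")
    except metadata.PackageNotFoundError:
        return None
```

A run function could then call posterior.copy_hyperparameters_from(posteriors[-1]) only when installed_sbi_version() is not None and has_legacy_posterior_api() returns True for it. Updating the code to the new API, as planned here, avoids the branch entirely.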

I hope to find time later this week to make a fixing PR.