sbi-benchmark / sbibm

Simulation-based inference benchmark
https://sbi-benchmark.github.io
MIT License

Alignment with SBI ABC API #20

Closed: psteinb closed this issue 2 years ago

psteinb commented 2 years ago

When running the following sbibm demo code at commit 074e06a,

import sbibm

task = sbibm.get_task("two_moons")  # See sbibm.get_available_tasks() for all tasks
prior = task.get_prior()
simulator = task.get_simulator()
observation = task.get_observation(num_observation=1)  # 10 per task

# These objects can then be used for custom inference algorithms, e.g.
# we might want to generate simulations by sampling from prior:
thetas = prior(num_samples=10_000)
xs = simulator(thetas)

# Alternatively, we can import existing algorithms, e.g:
from sbibm.algorithms import rej_abc  # See help(rej_abc) for keywords
posterior_samples, _, _ = rej_abc(task=task, num_samples=10_000, num_observation=1, num_simulations=100_000)

I get the following error:

task = <sbibm.tasks.two_moons.task.TwoMoons object at 0x7ff456f40f10>, num_samples = 50, num_simulations = 500, num_observation = 1
observation = tensor([[-0.6397,  0.1623]]), num_top_samples = 100, quantile = 0.2, eps = None, distance = 'l2', batch_size = 1000, save_distances = False
kde_bandwidth = 'cv', sass = False, sass_fraction = 0.5, sass_feature_expansion_degree = 3, lra = False

    def run(
        task: Task,
        num_samples: int,
        num_simulations: int,
        num_observation: Optional[int] = None,
        observation: Optional[torch.Tensor] = None,
        num_top_samples: Optional[int] = 100,
        quantile: Optional[float] = None,
        eps: Optional[float] = None,
        distance: str = "l2",
        batch_size: int = 1000,
        save_distances: bool = False,
        kde_bandwidth: Optional[str] = "cv",
        sass: bool = False,
        sass_fraction: float = 0.5,
        sass_feature_expansion_degree: int = 3,
        lra: bool = False,
    ) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
        """Runs REJ-ABC from `sbi`

        Choose one of `num_top_samples`, `quantile`, `eps`.

        Args:
            task: Task instance
            num_samples: Number of samples to generate from posterior
            num_simulations: Simulation budget
            num_observation: Observation number to load, alternative to `observation`
            observation: Observation, alternative to `num_observation`
            num_top_samples: If given, will use `top=True` with num_top_samples
            quantile: Quantile to use
            eps: Epsilon threshold to use
            distance: Distance to use
            batch_size: Batch size for simulator
            save_distances: If True, stores distances of samples to disk
            kde_bandwidth: If not None, will resample using KDE when necessary, set
                e.g. to "cv" for cross-validated bandwidth selection
            sass: If True, summary statistics are learned as in
                Fearnhead & Prangle 2012.
            sass_fraction: Fraction of simulation budget to use for sass.
            sass_feature_expansion_degree: Degree of polynomial expansion of the summary
                statistics.
            lra: If True, posterior samples are adjusted with
                linear regression as in Beaumont et al. 2002.
        Returns:
            Samples from posterior, number of simulator calls, log probability of true params if computable
        """
        assert not (num_observation is None and observation is None)
        assert not (num_observation is not None and observation is not None)

        assert not (num_top_samples is None and quantile is None and eps is None)

        log = sbibm.get_logger(__name__)
        log.info(f"Running REJ-ABC")

        prior = task.get_prior_dist()
        simulator = task.get_simulator(max_calls=num_simulations)
        if observation is None:
            observation = task.get_observation(num_observation)

        if num_top_samples is not None and quantile is None:
            if sass:
                quantile = num_top_samples / (
                    num_simulations - int(sass_fraction * num_simulations)
                )
            else:
                quantile = num_top_samples / num_simulations

        inference_method = MCABC(
            simulator=simulator,
            prior=prior,
            simulation_batch_size=batch_size,
            distance=distance,
            show_progress_bars=True,
        )
>       posterior, distances = inference_method(
            x_o=observation,
            num_simulations=num_simulations,
            eps=eps,
            quantile=quantile,
            return_distances=True,
            lra=lra,
            sass=sass,
            sass_expansion_degree=sass_feature_expansion_degree,
            sass_fraction=sass_fraction,
        )
E       TypeError: __call__() got an unexpected keyword argument 'return_distances'
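For reference, when `num_top_samples` is given and `quantile` is not, `run` converts the former into a quantile of the simulation budget. With `sass=True`, the part of the budget spent on learning summary statistics is subtracted first. The standalone sketch below mirrors that branch of the code above. The name `top_samples_to_quantile` is a hypothetical helper and not part of sbibm. With the locals shown in the traceback (`num_top_samples = 100`, `num_simulations = 500`), it reproduces `quantile = 0.2`.

```python
def top_samples_to_quantile(
    num_top_samples: int,
    num_simulations: int,
    sass: bool = False,
    sass_fraction: float = 0.5,
) -> float:
    """Mirror of the quantile computation in REJ-ABC `run` (hypothetical helper)."""
    if sass:
        # Part of the budget is spent learning summary statistics,
        # so only the remaining simulations are used for rejection.
        remaining = num_simulations - int(sass_fraction * num_simulations)
        return num_top_samples / remaining
    return num_top_samples / num_simulations


print(top_samples_to_quantile(100, 500))             # 0.2
print(top_samples_to_quantile(100, 500, sass=True))  # 0.4
```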
jan-matthis commented 2 years ago

Thanks for reporting this!

@janfb: Could it be that the APIs diverged with this sbi commit? https://github.com/mackelab/sbi/commit/9ed18ca1ad9ec9d93b1f09f473157ed4b5ca672c
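The `TypeError` is Python's generic response to a keyword argument that the callee's signature does not declare. One way to tell which API an installed sbi exposes is to inspect the signature before calling. The sketch below uses the standard-library `inspect` module. `accepts_kwarg` is a hypothetical helper, and `old_call` / `new_call` are stand-ins that only imitate an older and a newer `__call__`; they are not sbi code. With sbi installed, the same helper could be pointed at `MCABC.__call__`, assuming `MCABC` is importable from `sbi.inference` in that sbi version.

```python
import inspect
from typing import Callable


def accepts_kwarg(func: Callable, name: str) -> bool:
    """Return True if `func` can be called with keyword argument `name`."""
    for p in inspect.signature(func).parameters.values():
        # A **kwargs parameter accepts any keyword.
        if p.kind is inspect.Parameter.VAR_KEYWORD:
            return True
        # A named parameter accepts the keyword unless it is positional-only.
        if p.name == name and p.kind is not inspect.Parameter.POSITIONAL_ONLY:
            return True
    return False


# Hypothetical stand-ins for an older and a newer signature (not sbi code).
def old_call(x_o, num_simulations, eps=None, quantile=None, return_distances=False):
    ...


def new_call(x_o, num_simulations, eps=None, quantile=None):
    ...


print(accepts_kwarg(old_call, "return_distances"))  # True
print(accepts_kwarg(new_call, "return_distances"))  # False

try:
    new_call(x_o=None, num_simulations=10, return_distances=True)
except TypeError as err:
    # Same kind of message as in the traceback above.
    print(err)
```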

psteinb commented 2 years ago

I am therefore wondering whether setup.py should pin a specific sbi version.
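Besides a static pin in `install_requires` (a specifier of the form `sbi>=A,<B` covering the tested range), a package could check the installed sbi version at import time and fail with a clearer message than a `TypeError` deep inside a call. The sketch below is an illustration under assumptions, not sbibm code: `check_sbi_version` and its bounds are hypothetical placeholders (the thread does not say which sbi release changed the ABC API), and versions are compared as plain integer tuples, so pre-release tags such as `rc1` are not handled. The `packaging` library would be the more robust choice.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple


def parse_version(v: str) -> Tuple[int, ...]:
    """Parse a plain numeric 'X.Y.Z' version string into a tuple of ints."""
    return tuple(int(part) for part in v.split("."))


def version_in_range(installed: str, minimum: str, below: str) -> bool:
    """True if minimum <= installed < below (plain numeric versions only)."""
    return parse_version(minimum) <= parse_version(installed) < parse_version(below)


def check_sbi_version(minimum: str, below: str, dist: str = "sbi") -> None:
    """Raise ImportError if `dist` is missing or outside [minimum, below)."""
    try:
        installed = version(dist)
    except PackageNotFoundError:
        raise ImportError(f"{dist} is not installed") from None
    if not version_in_range(installed, minimum, below):
        raise ImportError(
            f"{dist} {installed} is installed, but this code expects "
            f">={minimum},<{below}"
        )


# Hypothetical usage: fill in the sbi range the package was actually tested against.
# check_sbi_version(minimum="...", below="...")
```

Comparing integer tuples rather than raw strings matters because, as strings, "1.10.0" sorts before "1.9.2".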