sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

Density Estimator batched sample mixes up samples from different posteriors #1154

Closed by manuelgloeckler 2 months ago

manuelgloeckler commented 2 months ago

Describe the bug

Given a batched observation, i.e., two observations x1 and x2 stacked along the batch dimension, the density estimator's sample method mixes up samples from the two different posteriors.

To Reproduce

import torch

from sbi import analysis, utils
from sbi.inference.base import infer

num_dim = 3
prior = utils.BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))

def simulator(parameter_set):
    return 1.0 + parameter_set + torch.randn(parameter_set.shape) * 0.1

posterior = infer(simulator, prior, method="SNPE", num_simulations=200)
observation = torch.stack([torch.zeros(3), torch.ones(3)])
posterior_samples = posterior.posterior_estimator.sample((1000,), condition=observation)

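# Expected: a unimodal distribution for the first observation.
# Observed: a multimodal distribution, because samples from the two different x_o's are mixed up.
# posterior_samples has shape (1000, 2, 3); index 0 on dim 1 selects the first observation.
samples1 = posterior_samples[:, 0].detach()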
_ = analysis.pairplot([samples1], limits=[[-2, 2], [-2, 2], [-2, 2]], figsize=(6, 6))
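Besides the pairplot, the mix-up can be checked numerically. Because the simulator returns x = 1 + θ + noise, the posterior for x_o = 0 should sit roughly around θ = -1 and the one for x_o = 1 roughly around θ = 0. If the slots are mixed, each batch entry's spread grows noticeably. The sketch below uses synthetic tensors in place of real sbi output so that it runs without training. The helper name per_observation_summary is hypothetical, and the (num_samples, batch_dim, event_dim) layout is the one the repro above assumes.

```python
import torch


def per_observation_summary(samples: torch.Tensor):
    """Mean and std per batch entry of (num_samples, batch_dim, event_dim) samples."""
    return samples.mean(dim=0), samples.std(dim=0)


torch.manual_seed(0)

# Synthetic stand-in for correct output: entry 0 near theta = -1 (x_o = 0),
# entry 1 near theta = 0 (x_o = 1).
good = torch.stack(
    [torch.randn(1000, 3) * 0.1 - 1.0, torch.randn(1000, 3) * 0.1], dim=1
)  # shape (1000, 2, 3)

# Synthetic stand-in for buggy output: the second half of the samples has its
# batch entries swapped, so each batch slot mixes both posteriors.
mixed = torch.cat([good[:500], torch.flip(good[500:], dims=[1])], dim=0)

good_mean, good_std = per_observation_summary(good)
mixed_mean, mixed_std = per_observation_summary(mixed)
print("correct:", good_mean[0], good_std[0])
print("mixed:  ", mixed_mean[0], mixed_std[0])
```

On the real repro one could call per_observation_summary(posterior_samples.detach()). With only 200 simulations the learned posteriors will be broad, so the check is a rough heuristic rather than a precise test.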

Additional context

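This is likely a reshaping bug: the samples for the different observations are probably reordered incorrectly when they are reshaped into a (num_samples, batch_dim, event_dim) tensor.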
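To illustrate the suspected failure mode, here is a minimal sketch in plain PyTorch. It is not sbi's actual sampling code. It assumes an estimator that repeats each condition batch-major (x1, x1, ..., x2, x2, ...), draws one sample per row, and then reshapes the flat result. The stand-in sampler fake_sampler is hypothetical. It returns its condition unchanged, so any mix-up is visible immediately.

```python
import torch


def fake_sampler(condition_rows: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a conditional density estimator: each
    # "sample" equals its (repeated) condition, so mix-ups are easy to spot.
    return condition_rows.clone()


num_samples, batch_dim, event_dim = 4, 2, 3
condition = torch.stack([torch.zeros(event_dim), torch.ones(event_dim)])  # (2, 3)

# Batch-major repetition: rows are [x1, x1, x1, x1, x2, x2, x2, x2].
repeated = condition.repeat_interleave(num_samples, dim=0)  # (8, 3)
flat = fake_sampler(repeated)

# Buggy: batch-major rows are read as if they were sample-major.
buggy = flat.reshape(num_samples, batch_dim, event_dim)
# Correct: split by batch first, then move the sample axis to the front.
correct = flat.reshape(batch_dim, num_samples, event_dim).transpose(0, 1)

print(buggy[:, 0, 0])    # tensor([0., 0., 1., 1.])
print(correct[:, 0, 0])  # tensor([0., 0., 0., 0.])
```

The buggy slice for the first observation contains samples belonging to both conditions, which matches the multimodal pairplot in the report. The fix is to make the reshape agree with the order in which the conditions were repeated. An alternative is to repeat sample-major, for example with condition.repeat(num_samples, 1), and keep the direct reshape.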