Closed: manuelgloeckler closed this issue 2 months ago.
Describe the bug
Given a batched observation, i.e., two observations x1 and x2 stacked into one tensor, the sampling method mixes up samples from the different posterior distributions.
To Reproduce
import torch

from sbi import analysis as analysis
from sbi import utils as utils
from sbi.inference.base import infer

num_dim = 3
prior = utils.BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))


def simulator(parameter_set):
    # x = 1 + theta + Gaussian noise with standard deviation 0.1
    return 1.0 + parameter_set + torch.randn(parameter_set.shape) * 0.1


posterior = infer(simulator, prior, method="SNPE", num_simulations=200)

# Batch of two observations, shape (2, 3)
observation = torch.stack([torch.zeros(3), torch.ones(3)])
posterior_samples = posterior.posterior_estimator.sample((1000,), condition=observation)

# Samples for the first observation (all zeros).
# Outputs a multimodal distribution, but should be unimodal (mixes up samples from the two different x_os)
samples1 = posterior_samples[:, 0].detach()
_ = analysis.pairplot([samples1], limits=[[-2, 2], [-2, 2], [-2, 2]], figsize=(6, 6))
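Besides inspecting the pairplot, the mix-up can be detected numerically. For this simulator, x = 1 + θ + 0.1·ε, so the true posterior for an observation x_o is concentrated around θ ≈ x_o − 1 (about −1 for the zeros observation and 0 for the ones observation). The sketch below assumes samples are laid out as (num_samples, batch, dim), as the indexing `posterior_samples[:, 0]` above implies, and uses NumPy with synthetic stand-in samples rather than a trained estimator. The helper `batch_means_match` and its tolerance are hypothetical, not part of sbi.

```python
import numpy as np


def batch_means_match(samples, observations, tol=0.2):
    """Hypothetical check: is each batch slice of `samples` centred near x_o - 1?

    samples: array of shape (num_samples, batch, dim) (assumed layout).
    observations: array of shape (batch, dim).
    For the simulator x = 1 + theta + noise, the posterior mean is roughly
    x_o - 1 (prior truncation is negligible for these observations).
    """
    expected = observations - 1.0      # shape (batch, dim)
    means = samples.mean(axis=0)       # shape (batch, dim)
    return bool(np.all(np.abs(means - expected) < tol))


rng = np.random.default_rng(0)
observations = np.stack([np.zeros(3), np.ones(3)])  # (2, 3), as in the report
num_samples = 1000

# Correctly batched synthetic samples: slice b is centred at observations[b] - 1.
correct = (observations - 1.0)[None, :, :] + 0.1 * rng.standard_normal((num_samples, 2, 3))

# Mixed-up samples: half of each slice comes from the other observation's posterior.
mixed = correct.copy()
half = num_samples // 2
mixed[:half, 0], mixed[:half, 1] = correct[:half, 1], correct[:half, 0]

print(batch_means_match(correct, observations))  # True
print(batch_means_match(mixed, observations))    # False
```

In the mixed case each slice has a mean of about −0.5 in every dimension, halfway between the two true posteriors, which matches the bimodal pairplot described above.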
Additional context
This is likely a "reshaping" bug.
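To illustrate how a reshaping bug could produce exactly this symptom, here is a hypothetical mechanism, not a diagnosis of the actual sbi code. Suppose samples for all observations are drawn as one flat array ordered sample-major, i.e., row i·B + b holds sample i for observation b. Reshaping that flat array with the batch axis first then interleaves samples from different observations. The sketch uses NumPy, and every value encodes its observation index so the mix-up is visible.

```python
import numpy as np

num_samples, batch, dim = 4, 2, 1

# Flat samples ordered sample-major: row i * batch + b belongs to observation b.
# Each value is the index of the observation it belongs to.
obs_index = np.tile(np.arange(batch), num_samples)  # [0, 1, 0, 1, ...]
flat = obs_index.reshape(num_samples * batch, dim).astype(float)

# Correct: sample axis first, matching how the flat array was ordered.
good = flat.reshape(num_samples, batch, dim)

# Buggy: batch axis first, although the data is sample-major.
bad = flat.reshape(batch, num_samples, dim).transpose(1, 0, 2)

print(good[:, 0, 0])  # [0. 0. 0. 0.]  -> all samples from observation 0
print(bad[:, 0, 0])   # [0. 1. 0. 1.]  -> observations 0 and 1 mixed
```

Neither reshape is wrong in itself: if the flat array were batch-major, the batch-first reshape would be the correct one. What matters is that the reshape order matches the order in which the samples were generated.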