sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

SNLE and SNRE fail with a 1D prior #1283

Open paarth-dudani opened 12 hours ago

paarth-dudani commented 12 hours ago

Describe the bug

I am running SNRE and SNLE (from the implemented algorithms) on a simple exponential simulator with noise and a 1D prior. Both algorithms work fine with a uniform prior, but with an exponential prior they fail with the error 'number of categories cannot exceed 2^24'.

To Reproduce

  1. Versions: Python 3.9.13, sbi 0.23.1

  2. Code for the SNLE implementation. I get the same error with the SNRE implementation (inference object: NRE).

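import numpy as np
import torch
from torch.distributions import Exponential
from sbi.inference import NLE
from sbi.utils import BoxUniform

## Function to introduce noise into the exponential: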
def noisy_exponential(scale, input):
    x_noise = np.zeros((len(input)),)
    for i in range(len(input)):
        value = 1/(scale)*np.exp(-input[i]/scale) + 0.05*np.random.normal(0,0.5)
        if value < 0:
            x_noise[i] = 0
        else:
            x_noise[i] = value
    return x_noise
n = 200
t = np.linspace(0,8,n)
# Define the prior
num_dims = 1
num_sims = 100
num_rounds = 2
theta_true = torch.tensor([2.0])

# prior = BoxUniform(low=0*torch.zeros(num_dims), high=4*torch.ones(num_dims))
exp_prior = Exponential(torch.tensor([2.0]))
simulator = lambda theta: ((1/theta)*np.exp(-t/theta) + 0.05*np.random.normal(0,0.5)).float()
x_o = torch.tensor(noisy_exponential(theta_true,t)).float()
# Inference 
inference = NLE(exp_prior)
proposal = exp_prior
for _ in range(num_rounds):
    theta = proposal.sample((num_sims,))
    x = simulator(theta)
    _ = inference.append_simulations(theta, x).train()
    posterior = inference.build_posterior(mcmc_method="slice_np_vectorized",
                                          mcmc_parameters={"num_chains": 2,
                                                           "thin": 5})
    proposal = posterior.set_default_x(x_o)
  3. RuntimeError: number of categories cannot exceed 2^24 (raised after the neural network has converged)

RuntimeError                              Traceback (most recent call last)
Input In [43], in <cell line: 17>()
     16 proposal = exp_prior
     17 for _ in range(num_rounds):
---> 18     theta = proposal.sample((num_sims,))
     19     x = simulator(theta)
     20     _ = inference.append_simulations(theta, x).train()

File ~/anaconda3/lib/python3.9/site-packages/sbi/inference/posteriors/mcmc_posterior.py:306, in MCMCPosterior.sample(self, sample_shape, x, method, thin, warmup_steps, num_chains, init_strategy, init_strategy_parameters, init_strategy_num_candidates, mcmc_parameters, mcmc_method, sample_with, num_workers, mp_context, show_progress_bars)
    303 init_strategy = _maybe_use_dict_entry(init_strategy, "init_strategy", m_p)
    304 self.potential_ = self._prepare_potential(method)  # type: ignore
--> 306 initial_params = self._get_initial_params(
    307     init_strategy,  # type: ignore
    308     num_chains,  # type: ignore
    309     num_workers,
    310     show_progress_bars,
    311     **init_strategy_parameters,
    312 )
    313 num_samples = torch.Size(sample_shape).numel()
    315 track_gradients = method in ("hmc_pyro", "nuts_pyro", "hmc_pymc", "nuts_pymc")

File ~/anaconda3/lib/python3.9/site-packages/sbi/inference/posteriors/mcmc_posterior.py:618, in MCMCPosterior._get_initial_params(self, init_strategy, num_chains, num_workers, show_progress_bars, **kwargs)
    615     initial_params = torch.cat(initial_params)  # type: ignore
    616 else:
... (output truncated; the final frame is in ~/anaconda3/lib/python3.9/site-packages/sbi/samplers/mcmc/init_strategy.py)
--> 112 idxs = torch.multinomial(probs, 1, replacement=False)
    113 # Return transformed sample.
    114 return transform(init_param_candidates[idxs, :])

RuntimeError: number of categories cannot exceed 2^24

Expected behavior

I expect the network to go through multiple rounds of training (2 in this example), or to give me a pairplot after one round of training (not shown above).

michaeldeistler commented 12 hours ago

Hi there,

thanks a lot for reporting this! The following will fix it:

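import torch
from torch.distributions import Exponential


class WrappedExponential(Exponential):
    """Exponential prior whose log_prob drops the trailing parameter dimension."""

    def log_prob(self, value):
        # (num_samples, 1) -> (num_samples,); squeeze only the last dimension
        # so that a batch containing a single sample keeps its batch dimension.
        return super().log_prob(value).squeeze(-1)


exp_prior = WrappedExponential(torch.tensor([2.0]))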

The issue arises because Exponential.log_prob returns a tensor of shape (num_samples, 1) when the rate has shape (1,), whereas any multivariate torch distribution (e.g., MultivariateNormal) returns a tensor of shape (num_samples,).
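
A small sketch of why the extra dimension matters, using only plain torch. The `likelihood_log_probs` tensor is a hypothetical stand-in for any per-sample quantity of shape (num_samples,); whether this exact broadcast is the path to the multinomial error inside sbi is an assumption, not something confirmed in this thread.

```python
import torch
from torch.distributions import Exponential

num_samples = 10
prior = Exponential(torch.tensor([2.0]))

theta = prior.sample((num_samples,))      # shape (10, 1)
prior_log_probs = prior.log_prob(theta)   # shape (10, 1)

# Hypothetical stand-in for a per-sample term of shape (num_samples,).
likelihood_log_probs = torch.zeros(num_samples)

# Broadcasting (10, 1) + (10,) silently yields a (10, 10) matrix.
broadcast_sum = prior_log_probs + likelihood_log_probs

# Dropping the trailing dimension restores one value per sample.
per_sample_sum = prior_log_probs.squeeze(-1) + likelihood_log_probs

print(broadcast_sum.shape, per_sample_sum.shape)
```

No exception is raised at the point of the addition, which is why a mismatch like this can surface much later, far from the prior itself.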

I will leave this issue open though because we should deal with this in process_prior, but we currently do not.

I hope that the above fixes the issue! Let me know if you have any more questions!

All the best Michael

michaeldeistler commented 12 hours ago

More notes for future fixing on our side:

This is only an issue for 1D pytorch distributions. The issue is that, e.g., torch.distributions.Exponential, torch.distributions.Normal, and torch.distributions.Uniform return samples and log-probs that either 1) both lack a trailing parameter dimension, or 2) both have one. However, sbi needs the trailing dimension on the samples but not on the log-probs.

import torch
from torch.distributions import Exponential, MultivariateNormal

prior = Exponential(torch.tensor(2.0))
samples = prior.sample((10,))  # (10,)
log_probs = prior.log_prob(samples)  # (10,)
# `sbi` raises an error because one must have a sample dimension.

prior = Exponential(torch.tensor([2.0]))
samples = prior.sample((10,))  # (10, 1)
log_probs = prior.log_prob(samples)  # (10, 1)
# `sbi` fails because the log_prob dimension contains the data dim.

# The covariance must be a 2D matrix, here of shape (1, 1).
prior = MultivariateNormal(torch.tensor([2.0]), torch.tensor([[2.0]]))
samples = prior.sample((10,))  # (10, 1)
log_probs = prior.log_prob(samples)  # (10,)
# `sbi` works.

IMO, the easiest fix would be to introduce the following:

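from typing import Dict

from torch import Tensor
from torch.distributions import Distribution, constraints


class OneDimDistributionWrapper(Distribution):
    def __init__(self, dist: Distribution):
        # Set `dist` first, because the properties below delegate to it.
        self.dist = dist
        # The wrapped distribution has already validated its own parameters.
        super().__init__(
            batch_shape=dist.batch_shape,
            event_shape=dist.event_shape,
            validate_args=False,
        )

    def sample(self, *args, **kwargs) -> Tensor:
        return self.dist.sample(*args, **kwargs)

    def log_prob(self, *args, **kwargs) -> Tensor:
        # Remove the additional dimension.
        return self.dist.log_prob(*args, **kwargs)[..., 0]

    @property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        return self.dist.arg_constraints

    @property
    def support(self):
        return self.dist.support

    @property
    def mean(self) -> Tensor:
        return self.dist.mean

    @property
    def variance(self) -> Tensor:
        return self.dist.variance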

We could then use this to wrap the 1D distributions which have a sample dimension.
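
For comparison, plain PyTorch already offers `torch.distributions.Independent`, which reinterprets trailing batch dimensions as event dimensions and sums the log-probs over them. The sketch below only shows the resulting shapes. It is not the fix proposed above, and whether sbi's process_prior accepts such an object was not tested in this thread.

```python
import torch
from torch.distributions import Exponential, Independent

base = Exponential(torch.tensor([2.0]))  # batch_shape (1,), event_shape ()
prior = Independent(base, 1)             # batch_shape (), event_shape (1,)

samples = prior.sample((10,))            # shape (10, 1)
log_probs = prior.log_prob(samples)      # shape (10,)

# With a single parameter, summing over the event dimension gives the same
# values as indexing away the trailing dimension of the base log-probs.
same_values = torch.allclose(log_probs, base.log_prob(samples)[..., 0])

print(samples.shape, log_probs.shape, same_values)
```

This gives the shape combination described above as working for sbi: a trailing dimension on the samples and none on the log-probs.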