paarth-dudani opened 12 hours ago
Hi there,
thanks a lot for reporting this! The following will fix it:
```python
import torch
from torch.distributions import Exponential

class WrappedExponential(Exponential):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def log_prob(self, *args, **kwargs):
        log_probs = super().log_prob(*args, **kwargs)
        return log_probs.squeeze()

exp_prior = WrappedExponential(torch.tensor([2.0]))
```
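To sanity-check the fix, one can compare the `log_prob` shapes before and after wrapping (a quick verification sketch, not from the original thread):

```python
import torch
from torch.distributions import Exponential

class WrappedExponential(Exponential):
    def log_prob(self, *args, **kwargs):
        # Squeeze out the trailing dimension of size 1.
        return super().log_prob(*args, **kwargs).squeeze()

plain = Exponential(torch.tensor([2.0]))
wrapped = WrappedExponential(torch.tensor([2.0]))

samples = plain.sample((10,))
print(plain.log_prob(samples).shape)    # torch.Size([10, 1])
print(wrapped.log_prob(samples).shape)  # torch.Size([10])
```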
The reason the issue was happening is that `Exponential.log_prob` returns a tensor of shape `(num_samples, 1)`, whereas any multivariate torch distribution (e.g., `MultivariateNormal`) returns a tensor of shape `(num_samples,)`.
I will leave this issue open though, because we should deal with this in `process_prior`, but we currently do not.
I hope that the above fixes the issue! Let me know if you have any more questions!
All the best,
Michael
More notes for future fixing on our side:
This is only an issue for 1D pytorch distributions. The problem is that, e.g., `torch.distributions.Exponential`, `torch.distributions.Normal`, and `torch.distributions.Uniform` all return either 1) no sample dimension and no log-prob dimension, or 2) both a sample dimension and a log-prob dimension. For `sbi`, however, we need a sample dimension but no log-prob dimension.
```python
import torch
from torch.distributions import Exponential, MultivariateNormal

prior = Exponential(torch.tensor(2.0))
samples = prior.sample((10,))        # (10,)
log_probs = prior.log_prob(samples)  # (10,)
# `sbi` raises an error because one must have a sample dimension.

prior = Exponential(torch.tensor([2.0]))
samples = prior.sample((10,))        # (10, 1)
log_probs = prior.log_prob(samples)  # (10, 1)
# `sbi` fails because the log_prob dimension contains the data dim.

prior = MultivariateNormal(torch.tensor([2.0]), torch.tensor([[2.0]]))
samples = prior.sample((10,))        # (10, 1)
log_probs = prior.log_prob(samples)  # (10,)
# `sbi` works.
```
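Building on the three cases above, a small helper can detect whether a distribution carries the problematic extra log-prob dimension. This is a sketch, not part of `sbi`'s API, and the function name is made up for illustration:

```python
import torch
from torch.distributions import Exponential, MultivariateNormal

def has_extra_logprob_dim(dist, num_samples: int = 10) -> bool:
    """Hypothetical check: does log_prob keep a trailing dimension
    instead of returning shape (num_samples,)?"""
    samples = dist.sample((num_samples,))
    return dist.log_prob(samples).shape != torch.Size([num_samples])

print(has_extra_logprob_dim(Exponential(torch.tensor([2.0]))))                  # True
print(has_extra_logprob_dim(MultivariateNormal(torch.zeros(1), torch.eye(1))))  # False
```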
IMO, the easiest fix would be to introduce the following:
```python
from typing import Dict

import torch
from torch import Tensor
from torch.distributions import constraints

class OneDimDistributionWrapper(torch.distributions.Distribution):
    def __init__(self, dist, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist = dist

    def sample(self, *args, **kwargs) -> Tensor:
        return self.dist.sample(*args, **kwargs)

    def log_prob(self, *args, **kwargs) -> Tensor:
        return self.dist.log_prob(*args, **kwargs)[..., 0]  # Remove the additional dimension.

    @property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        return self.dist.arg_constraints

    @property
    def support(self):
        return self.dist.support

    @property
    def mean(self) -> Tensor:
        return self.dist.mean

    @property
    def variance(self) -> Tensor:
        return self.dist.variance
```
We could then use this to wrap the 1D distributions which have a sample dimension.
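As an illustration of that idea, the sketch below wraps `Exponential` and checks the resulting shapes. It fills in the missing `self` arguments and imports, and is an assumption about how such a wrapper would be used, not `sbi` code:

```python
from typing import Dict

import torch
from torch import Tensor
from torch.distributions import Distribution, Exponential, constraints

class OneDimDistributionWrapper(Distribution):
    def __init__(self, dist: Distribution) -> None:
        # validate_args=False: argument validation is delegated to the wrapped dist.
        super().__init__(validate_args=False)
        self.dist = dist

    def sample(self, *args, **kwargs) -> Tensor:
        return self.dist.sample(*args, **kwargs)

    def log_prob(self, *args, **kwargs) -> Tensor:
        # Drop the trailing log-prob dimension.
        return self.dist.log_prob(*args, **kwargs)[..., 0]

    @property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        return self.dist.arg_constraints

prior = OneDimDistributionWrapper(Exponential(torch.tensor([2.0])))
samples = prior.sample((10,))
print(samples.shape)                  # torch.Size([10, 1])
print(prior.log_prob(samples).shape)  # torch.Size([10])
```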
**Describe the bug**
I am implementing SNRE and SNLE (from the implemented algorithms) on a simple exponential simulator model with noise and a prior. The algorithms work just fine with a uniform prior, but with the exponential prior they give the following error: `number of categories cannot exceed 2^24`.
**To Reproduce**
**Versions**
Python version: 3.9.13
SBI version: 0.23.1
Code for the SNLE implementation; I get the same error for the SNRE implementation as well (with inference object: `NRE`).
**Expected behavior**
I expect the network to undergo multiple rounds of training (2 in this example) or give me a pairplot after one round of training (not shown above).