sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

sampling without inferring the prior transform #713

Closed: Patrickens closed this issue 2 years ago

Patrickens commented 2 years ago

Dear SBI team,

I'm absolutely loving the package! It's very well written and documented.

I need to fit a distribution with convex polytope support, meaning that every sample theta must satisfy A @ theta <= b. I wrote the following support class for my prior:

```python
import torch
from torch.distributions.constraints import _Dependent


class _CannonicalPolytopeSupport(_Dependent):
    """Convex polytope support: all theta with A @ theta <= b."""

    def __init__(
            self,
            A: torch.Tensor,
            b: torch.Tensor,
            validation_tol: float = 1e-13,
    ):
        self._A = A  # (m, d): one row per linear inequality
        self._b = b + validation_tol  # (m,): bounds, relaxed by a small tolerance
        # event_dim is the number of event dimensions (theta is a vector), not d.
        super().__init__(is_discrete=False, event_dim=1)

    def check(self, value):
        # value: (..., d) -> (..., m); b broadcasts over the trailing dimension.
        satisfied = value @ self._A.T <= self._b
        # One boolean per sample: inside only if all m inequalities hold.
        return satisfied.all(dim=-1)
```
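The `check` contract is easiest to see on a concrete polytope. The sketch below is not from the thread: the simplex example and the helper name `in_polytope` are illustrative assumptions. It writes the 2-D simplex as `A @ theta <= b` and returns one boolean per sample, i.e. the `sample_shape + batch_shape` result that `torch.distributions` expects from `Constraint.check`:

```python
import torch

# Hypothetical example polytope: the 2-D simplex written as A @ theta <= b.
A = torch.tensor([[-1.0, 0.0],   # -theta_1 <= 0, i.e. theta_1 >= 0
                  [0.0, -1.0],   # -theta_2 <= 0, i.e. theta_2 >= 0
                  [1.0, 1.0]])   # theta_1 + theta_2 <= 1
b = torch.tensor([0.0, 0.0, 1.0])


def in_polytope(theta: torch.Tensor, A: torch.Tensor, b: torch.Tensor,
                tol: float = 1e-13) -> torch.Tensor:
    """One boolean per event: True iff every row of A @ theta <= b holds."""
    # theta: (..., d) -> theta @ A.T: (..., m); b broadcasts over the last dim.
    return (theta @ A.T <= b + tol).all(dim=-1)


points = torch.tensor([[0.2, 0.3],    # inside
                       [0.6, 0.6],    # violates theta_1 + theta_2 <= 1
                       [-0.1, 0.5]])  # violates theta_1 >= 0
print(in_polytope(points, A, b).tolist())  # [True, False, False]

# Extra leading batch dimensions are handled by broadcasting.
print(in_polytope(points.expand(4, 3, 2), A, b).shape)  # torch.Size([4, 3])
```

Reducing with `.all(dim=-1)` is what turns the per-inequality results into a per-sample answer.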

Now I run into the following issue: once a round of inference is done and I try to build a posterior, sbi assumes that a bounded, dependent support must have a registered bijective transform to an unconstrained space:

```
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Input In [32], in <module>
----> 1 posterior = snpec.build_posterior(density_estimator)

File C:\Miniconda3\envs\pta\lib\site-packages\sbi\inference\snpe\snpe_base.py:420, in PosteriorEstimator.build_posterior(self, density_estimator, prior, sample_with, mcmc_method, vi_method, mcmc_parameters, vi_parameters, rejection_sampling_parameters)
    417     # Otherwise, infer it from the device of the net parameters.
    418     device = next(density_estimator.parameters()).device.type
--> 420 potential_fn, theta_transform = posterior_estimator_based_potential(
    421     posterior_estimator=posterior_estimator, prior=prior, x_o=None
    422 )
    424 if sample_with == "rejection":
    425     if "proposal" in rejection_sampling_parameters.keys():

File C:\Miniconda3\envs\pta\lib\site-packages\sbi\inference\potentials\posterior_based_potential.py:46, in posterior_estimator_based_potential(posterior_estimator, prior, x_o)
     41 device = str(next(posterior_estimator.parameters()).device)
     43 potential_fn = PosteriorBasedPotential(
     44     posterior_estimator, prior, x_o, device=device
     45 )
---> 46 theta_transform = mcmc_transform(prior, device=device)
     48 return potential_fn, theta_transform

File C:\Miniconda3\envs\pta\lib\site-packages\sbi\utils\sbiutils.py:615, in mcmc_transform(prior, num_prior_samples_for_zscoring, enable_transform, device, **kwargs)
    613 # Prior with bounded support, e.g., uniform priors.
    614 if has_support and support_is_bounded:
--> 615     transform = biject_to(prior.support)
    616 # For all other cases build affine transform with mean and std.
    617 else:
    618     if hasattr(prior, "mean") and hasattr(prior, "stddev"):

File C:\Miniconda3\envs\pta\lib\site-packages\torch\distributions\constraint_registry.py:142, in ConstraintRegistry.__call__(self, constraint)
    140     factory = self._registry[type(constraint)]
    141 except KeyError:
--> 142     raise NotImplementedError(
    143         f'Cannot transform {type(constraint).__name__} constraints') from None
    144 return factory(constraint)

NotImplementedError: Cannot transform _CannonicalPolytopeSupport constraints
```

Is there any way around this issue? I've thought about such a transform (I have not implemented it yet, since it's a real headache), but for now I would just like to check whether a sample is within the support without transforming it to unconstrained space.
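Checking membership without any transform is already enough for plain rejection sampling. A minimal sketch, assuming the trained density estimator can be replaced by any torch distribution with `.sample()` (here a `MultivariateNormal` stand-in) and using a hypothetical simplex polytope; `sample_in_polytope` is an illustrative helper, not an sbi function:

```python
import torch
from torch.distributions import MultivariateNormal

# Hypothetical polytope (the 2-D simplex) written as A @ theta <= b.
A = torch.tensor([[-1.0, 0.0], [0.0, -1.0], [1.0, 1.0]])
b = torch.tensor([0.0, 0.0, 1.0])

# Stand-in for the trained density estimator: anything with .sample() works here.
proposal = MultivariateNormal(torch.zeros(2), covariance_matrix=0.25 * torch.eye(2))


def sample_in_polytope(dist, A, b, num_samples, batch_size=10_000, max_rounds=100):
    """Rejection sampling: draw from `dist`, keep only draws with A @ theta <= b."""
    accepted, num_accepted = [], 0
    for _ in range(max_rounds):
        draws = dist.sample((batch_size,))       # (batch_size, d)
        keep = (draws @ A.T <= b).all(dim=-1)    # (batch_size,) membership mask
        accepted.append(draws[keep])
        num_accepted += int(keep.sum())
        if num_accepted >= num_samples:
            return torch.cat(accepted)[:num_samples]
    raise RuntimeError("Acceptance rate too low; increase max_rounds or batch_size.")


torch.manual_seed(0)
samples = sample_in_polytope(proposal, A, b, num_samples=1_000)
print(samples.shape)  # torch.Size([1000, 2])
```

The accepted samples follow the proposal truncated to the polytope; the cost grows as the proposal mass inside the polytope (the acceptance rate) shrinks.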

Thank you for your efforts!

Patrickens commented 2 years ago

Perhaps I could just use rejection sampling via .build_posterior(density_estimator, sample_with='rejection'); then theta_transform is not used at all, and only potential_fn has to be built in posterior_estimator_based_potential. Or perhaps we can somehow pass the enable_transform=False kwarg from build_posterior to mcmc_transform.

michaeldeistler commented 2 years ago

Hi there! Thanks for the compliments, we're happy that you're enjoying the package :)

If you want to turn off the transformation, you should probably use the sampler interface.

Could you try this (I have not tested it yet):

```python
from sbi.inference import SNPE, DirectPosterior

# theta, x and prior are assumed to be defined already.
inference = SNPE(prior=prior)
posterior_estimator = inference.append_simulations(theta, x).train()
posterior = DirectPosterior(posterior_estimator, prior)
```

Patrickens commented 2 years ago

So the instantiation of DirectPosterior also calls posterior_estimator_based_potential and thus mcmc_transform. For now, I added an enable_theta_transform=False keyword argument to the DirectPosterior constructor, which is passed down to mcmc_transform, and this works! Thanks for thinking along.

michaeldeistler commented 2 years ago

Ah, great that you found this fix! I'll leave this issue open because we should probably make this easier...

janfb commented 2 years ago

Thanks @Patrickens, interesting problem!

But a plain identity transform should work on your custom support, right?
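Why an identity transform is a reasonable default here: the coordinates the sampler works in are theta itself and log|det J| is zero, so the potential is not reweighted. A minimal PyTorch sketch (wrapping `identity_transform` in `IndependentTransform(..., 1)` for vector-valued theta is our assumption, not necessarily what #714 does):

```python
import torch
from torch.distributions import transforms

# Identity over vector-valued theta: reinterpret the last dimension as the event.
identity = transforms.IndependentTransform(transforms.identity_transform, 1)

theta = torch.tensor([[0.2, 0.3], [0.6, 0.6]])  # two samples of a 2-D parameter
z = identity.inv(theta)  # "unconstrained" coordinates are just theta itself

# Zero log|det J| per sample: the potential is not reweighted by the transform.
log_det = identity.log_abs_det_jacobian(z, theta)
print(torch.equal(z, theta), log_det.shape)  # True torch.Size([2])
```

The catch is that the identity does not keep an MCMC sampler inside the polytope: proposals that violate A @ theta <= b have to receive a potential of -inf, for example through the prior's log_prob or a support check.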

I added a theta_transform=None kwarg to DirectPosterior, which is passed on to posterior_estimator_based_potential to either build an identity transform or construct one from the support; see #714.

Could you please paste your custom prior class here for testing?