sampling without inferring the prior transform #713

Patrickens commented 2 years ago

Dear SBI team,

I'm absolutely loving the package! Very well written and documented.

I need to fit a distribution that has convex polytope support, meaning that all samples theta need to satisfy the following: A @ theta <= b. I made the following support class for my prior:

import torch
from torch.distributions.constraints import _Dependent

class _CannonicalPolytopeSupport(_Dependent):
    """Support of the convex polytope {theta : A @ theta <= b}."""

    def __init__(
            self,
            A: torch.Tensor,  # (n_constraints, dim)
            b: torch.Tensor,  # (n_constraints,)
            validation_tol: float = 1e-13,
    ):
        self._A = A
        self._b = b + validation_tol  # loosen the bound slightly for numerical noise
        # event_dim=1: each sample theta is a single vector of length A.shape[1]
        super().__init__(is_discrete=False, event_dim=1)

    def check(self, value):
        # value: (..., dim) -> (..., n_constraints) -> one bool per sample
        return (value @ self._A.T <= self._b).all(dim=-1)
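For a concrete check, here is a small sketch using the 2-D unit simplex (theta_1 >= 0, theta_2 >= 0, theta_1 + theta_2 <= 1) written in the form A @ theta <= b. The helper `in_polytope` is a hypothetical name for illustration; it applies the A @ theta <= b test row-wise to a batch of samples and reduces over the constraints, so each sample gets a single True/False.

```python
import torch


def in_polytope(theta: torch.Tensor, A: torch.Tensor, b: torch.Tensor, tol: float = 1e-13) -> torch.Tensor:
    """Hypothetical helper: True for each sample that satisfies every row of A @ theta <= b + tol."""
    # theta: (..., dim), A: (n_constraints, dim), b: (n_constraints,)
    return (theta @ A.T <= b + tol).all(dim=-1)


# Unit simplex in 2-D: -theta_1 <= 0, -theta_2 <= 0, theta_1 + theta_2 <= 1
A = torch.tensor([[-1.0, 0.0], [0.0, -1.0], [1.0, 1.0]])
b = torch.tensor([0.0, 0.0, 1.0])

theta = torch.tensor([[0.2, 0.3], [0.8, 0.5], [-0.1, 0.4]])
print(in_polytope(theta, A, b))  # tensor([ True, False, False])
```

The second sample violates theta_1 + theta_2 <= 1 and the third violates theta_1 >= 0, so only the first is inside.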

Now I run into the following issue: once a round of inference is done and I try to build a posterior, sbi assumes that a bounded, dependent support must come with a bijective transform to an unconstrained space, and my custom support has none.
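On the PyTorch side, this is easy to reproduce: `torch.distributions.biject_to` only knows the constraint types registered with it, so a user-defined `_Dependent` subclass has no entry and the lookup raises `NotImplementedError`. Whether sbi reaches this exact call internally is an assumption here; the sketch only shows the PyTorch behaviour. `PolytopeSupport` is a hypothetical minimal constraint, not the class from the post.

```python
import torch
from torch.distributions import biject_to
from torch.distributions.constraints import _Dependent


class PolytopeSupport(_Dependent):
    """Hypothetical minimal constraint for {theta : A @ theta <= b}."""

    def __init__(self, A: torch.Tensor, b: torch.Tensor):
        self.A, self.b = A, b
        super().__init__(is_discrete=False, event_dim=1)

    def check(self, value: torch.Tensor) -> torch.Tensor:
        # one bool per sample: all constraints satisfied
        return (value @ self.A.T <= self.b).all(dim=-1)


support = PolytopeSupport(torch.eye(2), torch.ones(2))
try:
    biject_to(support)  # no transform is registered for this constraint type
except NotImplementedError as err:
    print("no bijection registered:", err)
```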

Is there any way around this issue? I've thought of such a transform (I haven't implemented it yet, since it's a real headache), but for now I would just like to check whether a sample lies within the support without transforming it to unconstrained space.

Thank you for your efforts!

Patrickens commented 2 years ago

Perhaps I could just use rejection sampling via .build_posterior(density_estimator, sample_with='rejection'). Then theta_transform is not used at all, and we only have to build the potential function in posterior_estimator_based_potential. Alternatively, perhaps the enable_transform=False kwarg could somehow be passed from build_posterior down to mcmc_transform.
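Rejection sampling avoids the transform because it only needs a membership test. A minimal sketch, assuming a uniform box proposal as a stand-in for the trained density estimator; `rejection_sample` is a hypothetical helper for illustration, not an sbi function.

```python
import torch


def rejection_sample(propose, A, b, num_samples, batch_size=1000, max_rounds=1000):
    """Hypothetical helper: draw from propose() restricted to {theta : A @ theta <= b}."""
    accepted = []
    n_accepted = 0
    for _ in range(max_rounds):
        candidates = propose(batch_size)  # (batch_size, dim)
        keep = (candidates @ A.T <= b).all(dim=-1)  # membership test only, no transform
        accepted.append(candidates[keep])
        n_accepted += int(keep.sum())
        if n_accepted >= num_samples:
            return torch.cat(accepted)[:num_samples]
    raise RuntimeError("acceptance rate too low; increase max_rounds")


# Stand-in proposal: uniform on the box [-1, 2]^2
box = torch.distributions.Uniform(torch.full((2,), -1.0), torch.full((2,), 2.0))
# Target support: unit simplex in 2-D
A = torch.tensor([[-1.0, 0.0], [0.0, -1.0], [1.0, 1.0]])
b = torch.tensor([0.0, 0.0, 1.0])

samples = rejection_sample(lambda n: box.sample((n,)), A, b, num_samples=500)
```

The cost is the acceptance rate: here the simplex covers roughly 1/18 of the box, so most proposals are discarded. A proposal concentrated on the polytope (such as a trained posterior) keeps this loop cheap.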

michaeldeistler commented 2 years ago

Hi there! Thanks for the compliments, we're happy that you're enjoying the package :)

If you want to turn off the transformation, you should probably use the sampler interface.

Could you try this (I haven't tested it yet):

from sbi.inference import SNPE, DirectPosterior

# theta, x: simulated parameters and data; prior: your custom polytope prior
inference = SNPE()
posterior_estimator = inference.append_simulations(theta, x).train()
posterior = DirectPosterior(posterior_estimator, prior)

Patrickens commented 2 years ago

When DirectPosterior is instantiated, there is also a call to posterior_estimator_based_potential and thus to mcmc_transform. For now, I added an enable_theta_transform=False keyword argument to the constructor of DirectPosterior, which is passed further down, and this works! Thanks for thinking along.
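The workaround boils down to a switch: with the transform disabled, use the identity and work directly in theta space; otherwise look up a bijection from the support to unconstrained space. A sketch under that reading; `make_theta_transform` is a hypothetical name and its flag only mirrors the enable_transform keyword mentioned earlier, so this does not reproduce sbi's internals.

```python
import torch
from torch.distributions import biject_to, constraints
from torch.distributions.transforms import ComposeTransform, Transform


def make_theta_transform(support: constraints.Constraint, enable_transform: bool = True) -> Transform:
    """Hypothetical helper: identity if disabled, else a map from the support to unconstrained space."""
    if not enable_transform:
        return ComposeTransform([])  # an empty composition is the identity map
    # biject_to maps unconstrained -> support, so its inverse maps support -> unconstrained
    return biject_to(support).inv


theta = torch.tensor([[0.2, 0.3]])
identity = make_theta_transform(constraints.real_vector, enable_transform=False)
to_unbounded = make_theta_transform(constraints.interval(0.0, 1.0))
print(identity(theta))      # unchanged
print(to_unbounded(theta))  # logit of theta
```

For a custom support with no registered bijection, only the disabled branch is usable, which is exactly the case in this issue.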

michaeldeistler commented 2 years ago

Ah, great that you found this fix! I'll leave this issue open, because we should probably make this easier...

janfb commented 2 years ago

Thanks @Patrickens, interesting problem!

But a plain identity transform should work on your custom support, right?
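Why an identity transform can be enough: a sampler then moves directly in theta space, the Jacobian correction is zero, and proposals outside the polytope get a log-prior of -inf and are never accepted. A minimal sketch, assuming a plain random-walk Metropolis sampler and a uniform prior on the polytope; `polytope_log_prior` and `random_walk_mh` are hypothetical helpers, not sbi's MCMC code.

```python
import torch


def polytope_log_prior(theta, A, b):
    """Unnormalised uniform log-density on {A @ theta <= b}: 0 inside, -inf outside."""
    inside = (theta @ A.T <= b).all(dim=-1)
    return torch.where(inside, torch.zeros(()), torch.tensor(float("-inf")))


def random_walk_mh(log_prob, theta0, num_steps=2000, step=0.2, seed=0):
    """Random-walk Metropolis in untransformed theta space; theta0 must lie inside the support."""
    gen = torch.Generator().manual_seed(seed)
    theta, lp = theta0.clone(), log_prob(theta0)
    chain = []
    for _ in range(num_steps):
        proposal = theta + step * torch.randn(theta.shape, generator=gen)
        lp_new = log_prob(proposal)
        # proposals outside the support have lp_new = -inf and are always rejected
        if torch.rand((), generator=gen).log() < lp_new - lp:
            theta, lp = proposal, lp_new
        chain.append(theta)
    return torch.stack(chain)


# Unit simplex in 2-D
A = torch.tensor([[-1.0, 0.0], [0.0, -1.0], [1.0, 1.0]])
b = torch.tensor([0.0, 0.0, 1.0])
chain = random_walk_mh(lambda t: polytope_log_prior(t, A, b), torch.tensor([0.2, 0.2]))
```

The trade-off is efficiency rather than correctness: near the boundary many proposals land outside and are rejected, which a proper bijection to unconstrained space would avoid.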

I added a theta_transform=None kwarg to DirectPosterior, which is passed on to posterior_estimator_based_potential to either build an identity transform or construct one from the support. See #714.

Could you please paste your custom prior class here for testing?