sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

Interchangeability of `Callable` and `BasePotential` #1223

Open schroedk opened 3 weeks ago

schroedk commented 3 weeks ago

I have a question regarding the interchangeability of the argument `potential_fn` of

```python
class NeuralPosterior(ABC):
    r"""Posterior $p(\theta|x)$ with `log_prob()` and `sample()` methods.<br/><br/>
    All inference methods in sbi train a neural network which is then used to obtain
    the posterior distribution. The `NeuralPosterior` class wraps the trained network
    such that one can directly evaluate the (unnormalized) log probability and draw
    samples from the posterior.
    """

    def __init__(
        self,
        potential_fn: Union[Callable, BasePotential],
        # ... (remaining arguments omitted in this excerpt)
```

For the callable case, it must be something like:

```python
def potential(theta=None, x_o=None):  # the observation argument must be named x_o
    ...
```

in contrast to `BasePotential`, which is a `Callable` with `theta` as a positional argument and `track_gradients` as a keyword argument, correct? Is this tested somewhere? I only found examples where the argument is of type `BasePotential`.

schroedk commented 3 weeks ago

Related #1055

janfb commented 3 weeks ago

Yes, correct.

The reason that the custom `potential_fn` has `theta` and `x_o` as arguments is that these quantities are required to calculate the "potential", i.e., the unnormalized posterior (log-)probability.

For a `BasePotential`, the `__call__` method does not take `x_o` as an argument, because `x_o` is set as a property at runtime.
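To make the difference concrete, here is a minimal sketch of a `BasePotential`-style object. It is not sbi's implementation: the class name `GaussianPotential`, the Gaussian form, and the use of plain floats instead of torch tensors are assumptions for illustration. The point is that `x_o` lives on the object and is set at runtime, so `__call__` only needs `theta` and `track_gradients`.

```python
from typing import Optional


class GaussianPotential:
    """Hypothetical stand-in for a `BasePotential` subclass (not sbi code)."""

    def __init__(self, sigma: float = 1.0):
        self.sigma = sigma
        self._x_o: Optional[float] = None  # observation is attached later

    @property
    def x_o(self) -> float:
        if self._x_o is None:
            raise ValueError("No observation set; assign `potential.x_o` first.")
        return self._x_o

    @x_o.setter
    def x_o(self, value: float) -> None:
        self._x_o = value

    def __call__(self, theta: float, track_gradients: bool = True) -> float:
        # In sbi, `track_gradients` toggles autograd; plain floats have no
        # autograd, so the flag is accepted here but has no effect.
        # Unnormalized Gaussian log-likelihood of x_o given theta.
        return -0.5 * ((self.x_o - theta) / self.sigma) ** 2


potential = GaussianPotential()
potential.x_o = 1.0  # set at runtime, not passed on every call
print(potential(0.0))  # -0.5
print(potential(3.0, track_gradients=False))  # -2.0
```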

If a user passes a custom potential, it is checked for the required arguments here:

https://github.com/sbi-dev/sbi/blob/593e1533738bdc9c747d50f502f7c1a47bf94248/sbi/inference/posteriors/base_posterior.py#L57-L69 and then wrapped as a `BasePotential` here:

https://github.com/sbi-dev/sbi/blob/593e1533738bdc9c747d50f502f7c1a47bf94248/sbi/inference/potentials/base_potential.py#L80-L97
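The check-then-wrap flow described above could look roughly like the sketch below. It only mirrors the description in this comment and is not the code behind the links; `check_callable_potential` and `CallableWrapper` are hypothetical names, and plain floats stand in for tensors.

```python
import inspect
from typing import Callable, Optional


def check_callable_potential(potential_fn: Callable) -> None:
    """Raise if `potential_fn` cannot be called with `theta` and `x_o` by name."""
    params = inspect.signature(potential_fn).parameters
    for key in ("theta", "x_o"):
        if key not in params:
            raise TypeError(
                f"A callable `potential_fn` must accept `{key}`, even if unused."
            )


class CallableWrapper:
    """Hypothetical wrapper giving a user callable a BasePotential-style call."""

    def __init__(self, potential_fn: Callable):
        check_callable_potential(potential_fn)  # runtime signature check
        self.potential_fn = potential_fn
        self.x_o: Optional[float] = None  # set later, once the observation is known

    def __call__(self, theta: float, track_gradients: bool = True) -> float:
        # Forward the stored observation as a keyword argument; the
        # `track_gradients` flag is kept only for interface parity.
        return self.potential_fn(theta=theta, x_o=self.x_o)


def my_potential(theta, x_o):
    return -((theta - x_o) ** 2)


wrapped = CallableWrapper(my_potential)
wrapped.x_o = 2.0
print(wrapped(1.0))  # -1.0
```

With this design a misnamed argument (for example `x0`) is reported when the wrapper is built, but still only at runtime.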

janfb commented 3 weeks ago

> Is this tested somewhere? I only found examples where the argument is of type BasePotential.

Yes, I had to dig a bit as well, but it's tested here:

https://github.com/sbi-dev/sbi/blob/593e1533738bdc9c747d50f502f7c1a47bf94248/tests/potential_test.py#L28-L37

Here you can see how we define a custom potential that depends on the inputs `theta` and `x_o`.
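The linked test is not reproduced here. As a hedged, stdlib-only illustration of the same idea, the custom potential below depends on both `theta` and `x_o`; the standard-normal form and the shift by `x_o` are assumptions chosen for simplicity, and plain floats replace tensors.

```python
import math


def potential(theta: float, x_o: float) -> float:
    """Hypothetical custom potential: standard-normal log-density at theta + x_o.

    The parameters are named `theta` and `x_o`, which is what the
    runtime signature check looks for.
    """
    z = theta + x_o
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)


# The maximum moves with the observation: it sits at theta = -x_o.
x_o = 1.5
grid = [i / 10 for i in range(-40, 41)]
best_theta = max(grid, key=lambda t: potential(theta=t, x_o=x_o))
print(best_theta)  # -1.5
```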

michaeldeistler commented 2 weeks ago

I think this can be closed, feel free to reopen if anything is still unclear!

janfb commented 2 weeks ago

I think it's a good starting point for refactoring the Callable potential API.

michaeldeistler commented 2 weeks ago

What would have to be done here? Just more docs?

janfb commented 2 weeks ago

At the moment, if a user passes just a `Callable` as potential, we test at runtime whether it has the required arguments, e.g., `theta`, `track_gradients`, and `x_o` (or so). This is brittle. It would be nicer to enforce this beforehand with types, e.g., by defining a `Protocol` that ensures the passed `Callable` has the correct signature.
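A minimal sketch of the `Protocol` idea, assuming plain `float` arguments in place of tensors; `CallablePotential` and `evaluate` are hypothetical names, not part of sbi. The protocol is enforced by a static type checker such as mypy or pyright, not at runtime: even with `@runtime_checkable`, `isinstance` only verifies that a `__call__` attribute exists, not its parameter names.

```python
from typing import Protocol


class CallablePotential(Protocol):
    """Hypothetical callback protocol for user-supplied potentials."""

    def __call__(self, theta: float, x_o: float) -> float: ...


def evaluate(potential_fn: CallablePotential, theta: float, x_o: float) -> float:
    # A type checker verifies at each call site that `potential_fn`
    # accepts `theta` and `x_o`; nothing is inspected at runtime.
    return potential_fn(theta=theta, x_o=x_o)


def good_potential(theta: float, x_o: float) -> float:
    return -((theta - x_o) ** 2)


def bad_potential(theta: float, x0: float) -> float:
    return -((theta - x0) ** 2)


print(evaluate(good_potential, theta=1.0, x_o=3.0))  # -4.0
# A static checker should reject the next line, because `bad_potential`
# names its second parameter `x0` rather than `x_o`:
# evaluate(bad_potential, theta=1.0, x_o=3.0)
```

Compared with the `inspect.signature` check, this moves the error to type-checking time, but it only helps users who run a type checker, so a runtime check may still be worth keeping as a fallback.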