schroedk opened 3 weeks ago
Related #1055
Yes, correct.
The reason the custom `potential_fn` takes `theta` and `x_o` as arguments is that both quantities are required to calculate the "potential", i.e., the unnormalized posterior log-probability.
For `BasePotential`, the `__call__` method does not take `x_o` as an argument, because `x_o` is set as a property at runtime.
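To make the difference between the two shapes concrete, here is a minimal pure-Python sketch. It assumes a hypothetical scalar toy model (Gaussian likelihood with unit variance and a standard-normal prior). `log_normal`, `potential_fn`, and `ToyPotential` are illustrative names, not sbi's actual classes, and plain floats stand in for tensors.

```python
import math
from typing import Optional


# Hypothetical toy model: x ~ N(theta, 1), theta ~ N(0, 1), scalars only.
def log_normal(x: float, mean: float, std: float = 1.0) -> float:
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std * math.sqrt(2 * math.pi))


# Shape 1: a plain callable; x_o is passed explicitly on every call.
def potential_fn(theta: float, x_o: float) -> float:
    """Unnormalized log posterior: log p(x_o | theta) + log p(theta)."""
    return log_normal(x_o, theta) + log_normal(theta, 0.0)


# Shape 2: a class in the spirit of BasePotential (hypothetical, not sbi's code);
# x_o is stored on the object at runtime and __call__ only takes theta.
class ToyPotential:
    def __init__(self) -> None:
        self._x_o: Optional[float] = None

    @property
    def x_o(self) -> float:
        if self._x_o is None:
            raise ValueError("x_o has not been set")
        return self._x_o

    @x_o.setter
    def x_o(self, value: float) -> None:
        self._x_o = value

    def __call__(self, theta: float, track_gradients: bool = True) -> float:
        # track_gradients is accepted only for signature compatibility;
        # this pure-Python sketch has no autograd.
        return log_normal(self.x_o, theta) + log_normal(theta, 0.0)


potential = ToyPotential()
potential.x_o = 0.5  # x_o is set once, at runtime
# Both shapes compute the same expression, so the values agree exactly.
print(potential(0.2) == potential_fn(0.2, 0.5))  # True
```

The class form fixes `x_o` once and then behaves like a function of `theta` alone, which is what lets samplers call it without knowing about the observation.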
If a user passes a custom potential, it is checked for the required arguments here:
https://github.com/sbi-dev/sbi/blob/593e1533738bdc9c747d50f502f7c1a47bf94248/sbi/inference/posteriors/base_posterior.py#L57-L69
and then wrapped as a `BasePotential` here:
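A rough sketch of what such a runtime check followed by wrapping could look like, using `inspect.signature` from the standard library. `check_potential_signature` and `CallablePotentialWrapper` are hypothetical names, and this is not the code at the linked lines; it only illustrates the general pattern of checking argument names and then adapting a `(theta, x_o)` callable to a `(theta, track_gradients)` call convention.

```python
import inspect
from typing import Callable, Optional

REQUIRED_ARGS = ("theta", "x_o")  # assumption: the two arguments discussed above


def check_potential_signature(potential_fn: Callable) -> None:
    """Raise if potential_fn does not accept the required named arguments."""
    params = inspect.signature(potential_fn).parameters
    missing = [name for name in REQUIRED_ARGS if name not in params]
    if missing:
        raise ValueError(f"potential_fn is missing required argument(s): {missing}")


class CallablePotentialWrapper:
    """Hypothetical adapter giving a plain callable a BasePotential-like interface."""

    def __init__(self, potential_fn: Callable, x_o: Optional[float] = None) -> None:
        check_potential_signature(potential_fn)
        self.potential_fn = potential_fn
        self.x_o = x_o

    def __call__(self, theta, track_gradients: bool = True):
        # x_o comes from the wrapper's state; track_gradients is ignored here.
        return self.potential_fn(theta=theta, x_o=self.x_o)


def my_potential(theta, x_o):
    return -(theta - x_o) ** 2


wrapped = CallablePotentialWrapper(my_potential, x_o=1.0)
print(wrapped(3.0))  # -4.0

try:
    CallablePotentialWrapper(lambda theta: theta)
except ValueError as err:
    print(err)  # potential_fn is missing required argument(s): ['x_o']
```

This kind of check only verifies parameter names; it says nothing about argument types or the return value.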
> Is this tested somewhere? I only found examples where the argument is of type `BasePotential`.
Yes, I had to dig a bit as well, but it's tested here:
Here, you can see how we define a custom potential that depends on the inputs `theta` and `x_o`.
I think this can be closed, feel free to reopen if anything is still unclear!
I think it's a good starting point for refactoring the `Callable` potential API.
What would have to be done here? Just more docs?
At the moment, if a user passes just a `Callable` as potential, we test at runtime whether it has the required arguments, e.g., `theta`, `track_gradients`, and `x_o` (or so). This is brittle. It would be nicer to do this beforehand with types, e.g., by defining a `Protocol` that ensures the passed `Callable` has the correct signature.
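A minimal sketch of the `Protocol` idea, assuming Python 3.8+ (`typing.Protocol`, positional-only `/` parameters). `CustomPotential`, `BasePotentialLike`, and `build_posterior` are hypothetical names. A static type checker such as mypy or pyright would flag a callable whose signature does not match the protocol; at runtime, however, `isinstance` against a `runtime_checkable` protocol only checks that `__call__` exists, so a protocol would complement rather than fully replace a runtime check.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class CustomPotential(Protocol):
    """Hypothetical protocol for a user-supplied potential taking theta and x_o."""

    def __call__(self, theta: float, x_o: float) -> float: ...


class BasePotentialLike(Protocol):
    """Hypothetical, stricter variant of the BasePotential call convention:
    theta positional-only, track_gradients keyword-only."""

    def __call__(self, theta: float, /, *, track_gradients: bool = True) -> float: ...


def build_posterior(potential_fn: CustomPotential) -> None:
    # A static checker (mypy, pyright) rejects arguments whose signature
    # does not match CustomPotential.__call__; nothing is checked at runtime here.
    ...


def good_potential(theta: float, x_o: float) -> float:
    return -(theta - x_o) ** 2


build_posterior(good_potential)  # passes static type checking

# At runtime, isinstance() only checks that __call__ exists, not its signature,
# so even a callable with the wrong signature passes:
print(isinstance(good_potential, CustomPotential))      # True
print(isinstance(lambda theta: theta, CustomPotential))  # True
```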
I have a question regarding the interchangeability of the argument `potential_fn` of
For the callable case, it must be something like:
in contrast to `BasePotential`, which is a `Callable` with `theta` as a positional argument and `track_gradients` as a keyword argument, correct? Is this tested somewhere? I only found examples where the argument is of type `BasePotential`.