sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

Mismatch between true and re-transformed parameters with custom prior #791

Closed pscicluna closed 1 year ago

pscicluna commented 1 year ago

Hi! I'm having an issue that seems related to https://github.com/mackelab/sbi/issues/716. I'm trying to do SNPE with a range of different models defined by downstream users.

Because I need this to be relatively general, I'm defining a custom prior based on a user-defined prior transform (to generate samples) and a log_prob function. Users have a lot of different requirements, so it is hard to force them into building a prior from torch Distributions.
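
For concreteness, here is a minimal sketch of the kind of wrapper I mean (the CustomUniformPrior class and its bounds are purely illustrative; I pass it through sbi's process_prior, which as I understand it also accepts optional bounds via custom_prior_wrapper_kwargs):

import torch
from sbi.utils import process_prior

class CustomUniformPrior:
    """Illustrative custom prior: only .sample() and .log_prob() are required."""
    def __init__(self, lower, upper):
        self.dist = torch.distributions.Uniform(lower, upper)

    def sample(self, sample_shape=torch.Size([])):
        return self.dist.sample(sample_shape)

    def log_prob(self, theta):
        # Sum over parameter dimensions to return one log-density per sample.
        return self.dist.log_prob(theta).sum(-1)

# process_prior wraps the custom object so sbi can use it; bounds are optional.
prior, *_ = process_prior(
    CustomUniformPrior(torch.zeros(2), torch.ones(2)),
    custom_prior_wrapper_kwargs=dict(
        lower_bound=torch.zeros(2), upper_bound=torch.ones(2)
    ),
)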

Sometimes, when building the posterior, I get the assertion error about a mismatch between true and re-transformed parameters. I have yet to understand exactly what causes it, as the behaviour occurs for some priors/models and not others: some priors work fine when no bounds are provided and then complain when I provide bounds, and some fail even without bounds.

At this point, I'm not sure where to begin debugging and avoiding this issue, so any insights or suggestions would be most welcome. Are there typical causes of this error, or particularly pathological distributions that require special attention, or is it more likely that I have a bug somewhere? Are there particular details that would be useful to you?

michaeldeistler commented 1 year ago

Quick question first: do you use multi-round SNPE or single-round?

Michael

michaeldeistler commented 1 year ago

Anyway, I recommend using the sampler interface (see the end of the tutorial). With that interface, you can explicitly turn off all transformations that are built from the prior, like this:

from sbi.inference import SNPE, DirectPosterior

# Train the density estimator as usual.
inference = SNPE(prior=prior)
net = inference.append_simulations(theta, x).train()

# Build the posterior explicitly; `enable_transform=False` turns off
# all prior-based transformations.
posterior = DirectPosterior(net, prior, enable_transform=False)

Turning off transformations will not affect the obtained posterior itself. It only changes the behaviour of the .map() function (the optimization then runs in untransformed rather than transformed space), but I don't expect that to make a big difference.
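
For reference, downstream usage is unchanged; a quick sketch (x_o is a placeholder for your observed data):

samples = posterior.sample((1000,), x=x_o)
map_estimate = posterior.map(x=x_o)  # with enable_transform=False, this optimizes in untransformed space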

Does this resolve your issues?

pscicluna commented 1 year ago

Thanks, in initial tests this does appear to have solved the problem! I will close the issue for now and re-open it if any further issues show up.

For background, this was happening with single-round inference; I was trying to make sure the bugs were gone there before exposing multi-round inference.

Performance-wise, this doesn't seem to have made any difference. I just had to update a few packages to make it work properly.

Thanks again!