pscicluna closed this issue 1 year ago.
Quick question first: do you use multi-round SNPE or single-round?
Michael
Anyway, I recommend using the sampler interface (see the end of the tutorial). With that interface, you can explicitly turn off all transformations that are built from the prior, like this:
```python
from sbi.inference import SNPE, DirectPosterior

# `prior`, `theta` and `x` are your prior, simulated parameters and simulation outputs.
inference = SNPE(prior=prior)
net = inference.append_simulations(theta, x).train()
posterior = DirectPosterior(net, prior, enable_transform=False)  # <- `False` turns off transforms
```
Turning off transformations will not have any impact on the obtained posterior itself. It only changes the behaviour of the `.map()` function (optimization in transformed vs. non-transformed space), and I don't expect that to have a big impact.
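For a prior with box bounds, a "transformed space" is typically reached through a bijection such as a scaled logit, which maps the interval (low, high) onto the whole real line. The pure-Python sketch below illustrates that idea under the assumption of scalar interval bounds; it is not sbi's actual transform code, and the function names are made up. It also shows a property that may be relevant here: a parameter that sits on or outside the bounds cannot be transformed at all, so a prior whose samples disagree with its declared bounds could trip up transform-based steps while plain sampling is unaffected.

```python
import math


def to_unbounded(theta: float, low: float, high: float) -> float:
    """Scaled logit: map theta in the open interval (low, high) to the real line."""
    u = (theta - low) / (high - low)
    return math.log(u) - math.log1p(-u)


def to_bounded(z: float, low: float, high: float) -> float:
    """Inverse of to_unbounded: scaled sigmoid back into (low, high)."""
    return low + (high - low) / (1.0 + math.exp(-z))


def round_trip_ok(theta: float, low: float, high: float, tol: float = 1e-9) -> bool:
    """True if theta survives the forward/inverse transform up to a relative tolerance."""
    try:
        back = to_bounded(to_unbounded(theta, low, high), low, high)
    except (ValueError, OverflowError):
        # On or outside the bounds the logit is undefined (math domain error).
        return False
    return abs(back - theta) <= tol * max(1.0, abs(theta))


print(round_trip_ok(0.5, 0.0, 1.0))  # interior point -> True
print(round_trip_ok(1.0, 0.0, 1.0))  # on the upper bound -> False
print(round_trip_ok(1.5, 0.0, 1.0))  # outside the bounds -> False
```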
Does this resolve your issues?
Thanks! In initial tests, this does appear to have solved the problem. I will close the issue for now and re-open it if any further issues show up.
For background, this was happening with single-round inference. Since it was single-round, I wanted to make sure the bugs were gone before exposing multi-round inference.
Performance-wise, this doesn't seem to have made any difference. I just had to update a few packages to make it work properly.
Thanks again!
Hi! I'm having some issues which seem to be related to issue https://github.com/mackelab/sbi/issues/716. I'm trying to do SNPE with a range of different models defined by downstream users.
Because I need this to be relatively general, I'm defining a custom prior based on a user-defined prior transform (to generate samples) and a user-defined log_prob. Users have many different requirements, so it is hard to force them to build their priors from torch Distributions.
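In the "prior transform" style (familiar from nested-sampling codes), the prior is specified as a function mapping uniform draws on the unit cube to parameters, plus a separate log-density. The plain-Python sketch below shows a hypothetical `UserPrior` class that exposes `sample` and `log_prob` methods built from two such callables. It only illustrates the pattern under these assumptions: an object handed to sbi would need to work with torch tensors, and `uniform_transform`/`uniform_log_prob` are made-up examples.

```python
import math
import random
from typing import Callable, List, Optional, Sequence


class UserPrior:
    """Hypothetical wrapper exposing sample/log_prob from user-supplied callables.

    prior_transform maps a point u in the unit cube [0, 1)^dim to a parameter vector;
    log_prob_fn returns the log density of a parameter vector.
    """

    def __init__(
        self,
        prior_transform: Callable[[List[float]], List[float]],
        log_prob_fn: Callable[[Sequence[float]], float],
        dim: int,
        seed: Optional[int] = None,
    ) -> None:
        self.prior_transform = prior_transform
        self.log_prob_fn = log_prob_fn
        self.dim = dim
        self._rng = random.Random(seed)

    def sample(self, n: int) -> List[List[float]]:
        # Push n uniform draws from the unit cube through the user's prior transform.
        return [
            self.prior_transform([self._rng.random() for _ in range(self.dim)])
            for _ in range(n)
        ]

    def log_prob(self, theta: Sequence[float]) -> float:
        # Delegate to the user's density, which should agree with the sampler.
        return self.log_prob_fn(theta)


# Usage: a uniform prior on [-2, 3), written in the prior-transform style.
def uniform_transform(u: List[float]) -> List[float]:
    return [-2.0 + 5.0 * u[0]]


def uniform_log_prob(theta: Sequence[float]) -> float:
    return -math.log(5.0) if -2.0 <= theta[0] < 3.0 else -math.inf


prior = UserPrior(uniform_transform, uniform_log_prob, dim=1, seed=0)
samples = prior.sample(5)
```

Keeping sampling and density as two independent user callables is flexible, but it also means nothing forces them to agree; that agreement has to be checked separately.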
Sometimes I get an assertion error reporting a mismatch when building the posterior. I have yet to understand what exactly causes this, as it occurs for some priors/models and not others: some priors work fine when no bounds are provided but fail when I provide bounds, while others fail when no bounds are provided.
At this point, I'm not sure where to begin debugging or how to avoid this issue, so any insights or suggestions would be most welcome. Are there typical causes of this error, or particularly pathological distributions that require special attention, or is it more likely that I have a bug somewhere? Are there particular details I could share that would be useful?
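A possible first debugging step, assuming the assertion stems from the user's sampler, log_prob, and declared bounds disagreeing with each other (a guess, not something established in this thread): draw samples and count how many fall outside or exactly on the bounds, or get a non-finite log_prob. The `check_prior_consistency` helper below is hypothetical and written in plain Python; with an sbi prior you would convert tensors to lists or rewrite it with torch operations.

```python
import math
import random
from typing import Callable, Dict, List, Sequence


def check_prior_consistency(
    sample_fn: Callable[[int], List[Sequence[float]]],
    log_prob_fn: Callable[[Sequence[float]], float],
    low: Sequence[float],
    high: Sequence[float],
    n: int = 1000,
) -> Dict[str, int]:
    """Count samples that break the declared bounds or have a non-finite log_prob."""
    counts = {"outside_bounds": 0, "on_boundary": 0, "nonfinite_log_prob": 0}
    for theta in sample_fn(n):
        pairs = list(zip(theta, low, high))
        if any(t < lo or t > hi for t, lo, hi in pairs):
            counts["outside_bounds"] += 1
        elif any(t == lo or t == hi for t, lo, hi in pairs):
            counts["on_boundary"] += 1
        if not math.isfinite(log_prob_fn(theta)):
            counts["nonfinite_log_prob"] += 1
    return counts


# Usage: a deliberately inconsistent prior that samples from [0, 2) but declares bounds [0, 1].
rng = random.Random(0)


def sample_fn(n: int) -> List[List[float]]:
    return [[2.0 * rng.random()] for _ in range(n)]


def log_prob_fn(theta: Sequence[float]) -> float:
    return -math.log(2.0) if 0.0 <= theta[0] < 2.0 else -math.inf


print(check_prior_consistency(sample_fn, log_prob_fn, low=[0.0], high=[1.0]))
```

Non-zero `outside_bounds` or `nonfinite_log_prob` counts would suggest looking at the prior definition before the inference code.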