sbi-benchmark / sbibm

Simulation-based inference benchmark
https://sbi-benchmark.github.io
MIT License

Adding methods to prior for compatibility with sbi package #44

Closed: ntolley closed this issue 2 years ago

ntolley commented 2 years ago

I've noticed that the prior object returned by task.get_prior() is not immediately usable with the sbi package, since it has no .sample() or .log_prob() methods. Specifically, attempting something like this fails:

prior = task.get_prior()
inference = sbi.inference.SNPE(prior=prior, ...)

Looking at the code, I imagine this could be implemented by having task.get_prior() return an object instead of a function. Its class could then define a __call__() method, alongside .sample() and .log_prob(), to maintain compatibility with the current API. Happy to give this a shot if you guys agree with the change.
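A minimal sketch of the proposed change, assuming the prior is backed by a torch-style distribution (like `prior_dist` in the linked task code) and that the existing function is called as `prior(num_samples)`. The class name `CallablePrior` and the 5-dimensional uniform example are hypothetical, chosen only for illustration; this is not sbibm's actual implementation.

```python
import torch
from torch.distributions import Distribution, Independent, Uniform


class CallablePrior:
    """Hypothetical wrapper: callable like the current prior function,
    but also exposing .sample() and .log_prob() of the wrapped distribution."""

    def __init__(self, dist: Distribution):
        self.dist = dist

    def __call__(self, num_samples: int = 1) -> torch.Tensor:
        # Keeps the existing call style: prior(num_samples) -> samples
        return self.dist.sample((num_samples,))

    def sample(self, sample_shape=torch.Size()) -> torch.Tensor:
        # Delegates to the distribution, matching the torch signature
        return self.dist.sample(sample_shape)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        return self.dist.log_prob(value)


# Illustrative prior: independent Uniform(-3, 3) in each of 5 dimensions
prior_dist = Independent(Uniform(-3.0 * torch.ones(5), 3.0 * torch.ones(5)), 1)
prior = CallablePrior(prior_dist)

theta = prior(num_samples=10)   # old API, shape (10, 5)
theta2 = prior.sample((4,))     # sbi-style API, shape (4, 5)
lp = prior.log_prob(theta)      # one log density per sample, shape (10,)
```

Delegating to the wrapped distribution keeps the old call style working while providing the `.sample()`/`.log_prob()` interface; whether sbi accepts such an object directly, rather than a `torch.distributions.Distribution`, may depend on the prior checks in the sbi version in use.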

Edit: simply exposing prior_dist would actually suffice: https://github.com/sbi-benchmark/sbibm/blob/15f068a08a938383116ffd92b92de50c580810a3/sbibm/tasks/slcp/task.py#L60

jan-matthis commented 2 years ago

Thanks for the question! You should be able to access it via task.get_prior_dist(), for example. Hope this helps!

ntolley commented 2 years ago

Oh that's perfect! Thanks for the quick response.