sbi-benchmark / sbibm

Simulation-based inference benchmark
https://sbi-benchmark.github.io
MIT License

Refactoring `run` for additional flexibility #40

Open psteinb opened 2 years ago

psteinb commented 2 years ago

Not sure whether I am overlooking something, but the run methods in the algorithms only return the predicted samples, nothing else.

It might be worthwhile to consider refactoring this so that each Python module in the algorithms directory can also return the obtained posterior. In pseudocode, this would entail:

def train(...):
    return trained_objects

def infer(...):
    return predicted_objects

def run(...):
    trained_objects = train(...)
    predicted_objects = infer(trained_objects, ...)
    return predicted_objects

This refactoring would not change the API used downstream. It would, however, allow further analyses of the obtained posterior (e.g., mean/median MAP estimation versus SGD-based MAP estimation).
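The pseudocode above can be sketched as a self-contained toy module. This is only an illustration of the proposed split, not sbibm's actual API: train, infer, run, and the "trained object" dictionary are hypothetical stand-ins, and the "training" here is a dummy computation.

```python
# Hypothetical sketch of the proposed train/infer split. All names and the
# toy "training" logic are illustrative; they do not reflect sbibm internals.
import random
from typing import List, Tuple


def train(num_simulations: int, seed: int = 0) -> dict:
    """Stand-in for the training phase; returns the 'trained object'."""
    random.seed(seed)
    # Pretend the posterior is summarized by a mean estimated from simulations.
    draws = [random.gauss(0.0, 1.0) for _ in range(num_simulations)]
    mean = sum(draws) / num_simulations
    return {"posterior_mean": mean, "num_simulations": num_simulations}


def infer(trained: dict, num_samples: int) -> List[float]:
    """Stand-in for drawing samples from the trained posterior."""
    return [random.gauss(trained["posterior_mean"], 1.0) for _ in range(num_samples)]


def run(num_simulations: int, num_samples: int) -> Tuple[List[float], int]:
    """Keeps the downstream-facing signature: compose train + infer."""
    trained = train(num_simulations)
    samples = infer(trained, num_samples)
    return samples, trained["num_simulations"]


samples, n_sims = run(num_simulations=100, num_samples=10)
```

Because run simply composes the two new functions, existing callers of run are unaffected, while new analyses can call train directly and inspect the trained object.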

jan-matthis commented 2 years ago

Thanks for the suggestion, Peter!

I agree that for additional analyses it might be useful to be able to access the "trained object" (currently, depending on the algorithm, there is one trained object or none, e.g., for some ABC methods). Note that run returns not only posterior samples: it also returns the number of simulations that were actually performed (simulator.num_simulations, as a safety check) and, if the algorithm supports it, the log probability of the true parameters as an optional third return value.
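The return convention described above might be unpacked like this. This is a sketch only: run_algorithm is a hypothetical stand-in with dummy values, not the actual sbibm function.

```python
# Sketch of the (samples, num_simulations, optional log-prob) return
# convention. run_algorithm is a hypothetical stand-in, not sbibm's API.
from typing import List, Optional, Tuple


def run_algorithm(num_samples: int) -> Tuple[List[float], int, Optional[float]]:
    samples = [0.0] * num_samples      # predicted posterior samples
    num_simulations = 1000             # simulations actually performed (safety check)
    log_prob_true_parameters = None    # optional third return, if supported
    return samples, num_simulations, log_prob_true_parameters


samples, num_sims, log_prob = run_algorithm(num_samples=5)
```

A refactoring along the proposed lines would keep this three-part return for downstream code while additionally exposing the trained object via a separate train function.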

If you are up for it, I'd be glad to accept a PR that refactors the functions along the lines you propose and continue discussing there :)