sbi-benchmark / sbibm

Simulation-based inference benchmark
https://sbi-benchmark.github.io
MIT License

refactoring to split run methods (*mc, snpe) #42

Open psteinb opened 2 years ago

psteinb commented 2 years ago

I think splitting off only the top part of run (the part that builds the posterior) makes the most sense for now. Would be lovely to hear your thoughts @jan-matthis.

Closes #40

jan-matthis commented 2 years ago

Thanks for tackling this!

I guess splitting things into two functions (build_posterior and run) rather than three (train, infer, run, as you originally proposed) simplified the refactor with respect to passing the right arguments around: in the PR, you simply pass them all using locals(). Even the docstrings are very straightforward now. The only difference between build_posterior and run is what they return; with respect to arguments, they are completely equivalent.
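A minimal sketch of the two-function split, assuming a drastically simplified, hypothetical signature (the real sbibm functions take a task object and many more arguments, and build_posterior here just returns a placeholder dict). It shows how run can forward all of its arguments to build_posterior via locals(), which only works because both functions accept exactly the same parameters:

```python
from typing import Any, Dict, List, Tuple


def build_posterior(
    task: str, num_samples: int, num_simulations: int, seed: int = 0
) -> Dict[str, Any]:
    """Hypothetical stand-in for the training part of a run method."""
    # Placeholder "posterior"; real code would train a density estimator
    # (SNPE) or perform rejection ABC (MCABC) here.
    return {"task": task, "num_simulations": num_simulations, "seed": seed}


def run(
    task: str, num_samples: int, num_simulations: int, seed: int = 0
) -> Tuple[List[int], int]:
    """Same parameters as build_posterior; only the return value differs."""
    # As the first statement, locals() holds exactly the parameters of run,
    # so they can be forwarded without listing each one by name.
    posterior = build_posterior(**locals())
    # Placeholder sampling step: repeat the seed num_samples times.
    samples = [posterior["seed"]] * num_samples
    return samples, posterior["num_simulations"]
```

The trade-off is tight coupling: adding a parameter to one function but not the other makes the `**locals()` call fail with a TypeError.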

Before proceeding, I think it would be good to discuss two alternatives to this refactor, namely:

  1. Simply adding a flag such as save_posterior to the run method, which stores the posterior object to disk (and similarly for other artifacts that might be of interest)
  2. Equipping run with optional returns that can be requested via a flag
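To make the two alternatives concrete, here is a minimal sketch of each. The function names run_with_save, run_with_return and _train, and the flags save_posterior and return_posterior, are hypothetical; the posterior is a placeholder dict, and storing it with pickle is only an assumption about how it might be saved:

```python
import pickle
from pathlib import Path
from typing import Optional, Tuple, Union


def _train(num_simulations: int) -> dict:
    """Hypothetical placeholder for the posterior-building step."""
    return {"num_simulations": num_simulations}


# Alternative 1: an optional path; when given, the posterior is pickled to
# disk as a side effect, while the return value stays unchanged.
def run_with_save(
    num_simulations: int, save_posterior: Optional[Union[str, Path]] = None
) -> Tuple[str, int]:
    posterior = _train(num_simulations)
    if save_posterior is not None:
        with open(save_posterior, "wb") as f:
            pickle.dump(posterior, f)
    return "samples", num_simulations


# Alternative 2: a boolean flag that appends the posterior to the returned
# tuple, so the number of returned values depends on the arguments.
def run_with_return(num_simulations: int, return_posterior: bool = False) -> tuple:
    posterior = _train(num_simulations)
    if return_posterior:
        return "samples", num_simulations, posterior
    return "samples", num_simulations
```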

Would be great to hear your thoughts on this!

psteinb commented 2 years ago

My feedback on the two alternatives would be:

  1. As the list of parameters/arguments to run is already long, this appears to be the least intrusive change to the interface (we can set the default value to None, so that no downstream code has to change). However, it introduces a side effect (writing to disk), which is not so convenient in the long run.

  2. Also an option, but it might break all uses of

    v1, v2, _, _ = run(**my_args)

    If my_args makes run return 5 objects (i.e. the posterior in addition), unpacking them into four targets fails with a ValueError (too many values to unpack).
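The failure mode described in item 2 can be reproduced with a hypothetical run that returns four values by default and five when a hypothetical return_posterior flag is set:

```python
def run(return_posterior: bool = False) -> tuple:
    """Hypothetical run(): four return values, or five on request."""
    results = ("samples", 1000, None, None)
    if return_posterior:
        return results + ("posterior",)
    return results


# Existing call sites that unpack four values keep working.
v1, v2, _, _ = run()

# A call site that requests the posterior but keeps the old four-target
# unpacking fails, because five values cannot be unpacked into four names.
try:
    v1, v2, _, _ = run(return_posterior=True)
    message = ""
except ValueError as err:
    message = str(err)
    print(message)
```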

Based on these considerations, I would propose going with option 1, while still contemplating the split I suggested earlier. In other words:

psteinb commented 2 years ago

I gave this a try in https://github.com/sbi-benchmark/sbibm/pull/42/commits/633487fa228013cb92e09e577884a9ecf2e7fc79 (only for mcabc.py at the moment). Some observations I made:

Bottom line:

Please let me know what you think. I saw at least 5 functions that should be refactored in this fashion.

psteinb commented 2 years ago

Looking into this further, it might make sense to restructure algorithms.sbi around an abstract base class (e.g. Fit, Fitter, ...) and then implement specialisations for mcabc and snpe. In a class structure, handling parameters might be a bit easier.
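A minimal sketch of such a class structure, using Python's abc module. Fitter is one of the names floated above; MCABCFitter, SNPEFitter, the quantile parameter and the constant placeholder posterior are hypothetical. Shared parameters live on the instance, so run() no longer needs to forward them:

```python
from abc import ABC, abstractmethod
from typing import Any, List, Tuple


class ConstantPosterior:
    """Placeholder posterior whose samples are all the same value."""

    def __init__(self, value: float):
        self.value = value

    def sample(self, num_samples: int) -> List[float]:
        return [self.value] * num_samples


class Fitter(ABC):
    """Abstract base class holding the parameters shared by all algorithms."""

    def __init__(self, task: str, num_samples: int, num_simulations: int):
        self.task = task
        self.num_samples = num_samples
        self.num_simulations = num_simulations

    @abstractmethod
    def build_posterior(self) -> Any:
        """Algorithm-specific step; each subclass must implement it."""

    def run(self) -> Tuple[List[float], int]:
        # Shared driver: parameters are read from self, not passed around.
        posterior = self.build_posterior()
        return posterior.sample(self.num_samples), self.num_simulations


class MCABCFitter(Fitter):
    def __init__(self, task: str, num_samples: int, num_simulations: int,
                 quantile: float = 0.01):
        super().__init__(task, num_samples, num_simulations)
        self.quantile = quantile  # algorithm-specific parameter

    def build_posterior(self) -> ConstantPosterior:
        return ConstantPosterior(self.quantile)


class SNPEFitter(Fitter):
    def build_posterior(self) -> ConstantPosterior:
        return ConstantPosterior(0.0)
```

Because build_posterior is abstract, Fitter itself cannot be instantiated, and each specialisation only adds the parameters specific to its algorithm.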