sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

Refactor abstract classes for custom density estimators #1046

Closed: janfb closed this issue 2 months ago

janfb commented 6 months ago

Prompted by @tomMoral's input in #1019, we are planning to give users more flexibility in defining their custom density estimators by adding another layer of abstraction -- an Estimator base class.

Here is a draft resulting from a discussion with @michaeldeistler and @manuelgloeckler

Indent level shows inheritance; class methods are in parentheses:

[image: draft of the proposed class hierarchy]

see also #1041
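
A minimal, purely illustrative sketch of what such an `Estimator` base class could look like. The method names below (`loss`, `sample`, `log_prob`) are taken from the later discussion in this thread, not from the image, and the class names are hypothetical:

```python
# Hypothetical sketch only -- not the actual draft from the image above.
from abc import ABC, abstractmethod

import torch
from torch import Tensor, nn


class Estimator(nn.Module, ABC):
    """Abstract base class wrapping a neural network `net`."""

    def __init__(self, net: nn.Module) -> None:
        super().__init__()
        self.net = net

    @abstractmethod
    def loss(self, input: Tensor, condition: Tensor) -> Tensor:
        """Return a per-sample training loss."""
        ...

    @abstractmethod
    def sample(self, sample_shape: torch.Size, condition: Tensor) -> Tensor:
        """Draw samples given a condition."""
        ...


class DensityEstimator(Estimator):
    """Adds density evaluation on top of the base estimator."""

    @abstractmethod
    def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
        """Return log-probabilities of `input` given `condition`."""
        ...
```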

janfb commented 6 months ago

This will be relevant for #963 and the RatioEstimator as well, @bkmi.

bkmi commented 6 months ago

@michaeldeistler @janfb @manuelgloeckler @jnsbck

I propose we remove the method loss from this abstraction. It should be possible to define an estimator without assuming a particular way to train it.

This is an issue for MDN, all ratio estimators, and I suspect it will be an issue for the vector-based estimators as well.

The MDN can be trained in multiple ways, depending on whether it is incorporated into SNPE_A or SNPE_C. The ratio estimators differ only in how they are trained--NOT in the features of the estimator itself. Similarly, flow matching and score matching will require extremely similar estimators but different losses.

I argue that an Estimator should, at the time it is instantiated, be agnostic to the training algorithm; otherwise, why not include this abstraction in the training algorithm itself?

The loss should be at the "inference" level (i.e. class SNRE_A, SNPE_B, etc.), rather than at the estimator level.
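
A toy sketch of what this separation could look like: the estimator exposes only its forward pass, and the inference-level class owns the loss. Names and the loss itself are schematic, not sbi's actual implementation:

```python
# Schematic only: illustrates "loss lives at the inference level".
import torch
from torch import Tensor, nn


class RatioEstimator(nn.Module):
    """Estimator without a loss: it only maps (theta, x) to a logit."""

    def __init__(self, net: nn.Module) -> None:
        super().__init__()
        self.net = net

    def forward(self, theta: Tensor, x: Tensor) -> Tensor:
        return self.net(torch.cat([theta, x], dim=-1))


class SNRE_A:
    """Inference-level class that defines how the estimator is trained."""

    def __init__(self, ratio_estimator: RatioEstimator) -> None:
        self.ratio_estimator = ratio_estimator

    def loss(self, theta: Tensor, x: Tensor) -> Tensor:
        # Toy classification loss: joint pairs vs. shuffled (marginal) pairs.
        logits_joint = self.ratio_estimator(theta, x)
        logits_marginal = self.ratio_estimator(theta[torch.randperm(len(theta))], x)
        labels = torch.cat(
            [torch.ones_like(logits_joint), torch.zeros_like(logits_marginal)]
        )
        return nn.functional.binary_cross_entropy_with_logits(
            torch.cat([logits_joint, logits_marginal]), labels
        )
```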

What do you all think?

jnsbck commented 6 months ago

I think abstracting the model, loss, and optimization/training separately makes a lot of sense. It would require a ton of changes to the code in inference, though.

bkmi commented 6 months ago

FYI, I think the answer to this was to let loss exist in DensityEstimator, but not in the other base classes.

janfb commented 2 months ago

This is solved now: we have a ConditionalEstimator abstract base class in neural_nets/density_estimators that takes care of shapes and requires children to implement loss, log_prob, and sample. ConditionalDensityEstimator is the class for most flows (nflows and zuko), and ConditionalVectorFieldEstimator will be the class for score-matching and (maybe) flow-matching methods.

The RatioEstimator is separated from that and lives in its own ratio_estimators.py bubble.
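
A condensed, illustrative sketch of that layout; signatures are simplified and not copied from the sbi source:

```python
# Illustrative sketch of the final abstract-class layout described above.
from abc import ABC, abstractmethod

import torch
from torch import Tensor, nn


class ConditionalEstimator(nn.Module, ABC):
    """Handles input/condition shapes; children implement the API below."""

    def __init__(self, input_shape: torch.Size, condition_shape: torch.Size) -> None:
        super().__init__()
        self.input_shape = input_shape
        self.condition_shape = condition_shape

    @abstractmethod
    def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: ...

    @abstractmethod
    def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: ...

    @abstractmethod
    def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tensor: ...


class ConditionalDensityEstimator(ConditionalEstimator):
    """Base class for most flows (e.g. nflows and zuko backends)."""


class ConditionalVectorFieldEstimator(ConditionalEstimator):
    """Intended for score-matching (and possibly flow-matching) methods."""
```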

The neural_nets module is still kind of a mess, I think, and should be refactored in the future; see #1190.