sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0
578 stars, 145 forks

feat: base conditional estimator class #1151

Closed: manuelgloeckler closed this 3 months ago

manuelgloeckler commented 5 months ago

What does this implement/fix? Explain your changes

Implements an abstract base class for estimators, which has:
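For readers unfamiliar with the pattern, here is a minimal pure-Python sketch of what an abstract base class for conditional estimators might look like. It illustrates the general idea under stated assumptions and is not the class added in this PR. The names `ConditionalEstimator`, `_log_prob`, `loss` and `LinearGaussianEstimator` are hypothetical, and sbi's actual estimators are PyTorch modules operating on tensors, not Python lists. The base class owns the shared logic (shape checks and a default negative log-likelihood loss), and subclasses implement only the density evaluation.

```python
import math
from abc import ABC, abstractmethod
from typing import Sequence


class ConditionalEstimator(ABC):
    """Hypothetical base class for estimators of p(x | condition)."""

    def __init__(self, input_dim: int, condition_dim: int) -> None:
        self.input_dim = input_dim
        self.condition_dim = condition_dim

    def log_prob(self, x: Sequence[float], condition: Sequence[float]) -> float:
        # Shape checks live in the base class, so every subclass inherits them.
        if len(x) != self.input_dim:
            raise ValueError(f"x must have length {self.input_dim}, got {len(x)}")
        if len(condition) != self.condition_dim:
            raise ValueError(
                f"condition must have length {self.condition_dim}, got {len(condition)}"
            )
        return self._log_prob(x, condition)

    def loss(
        self,
        xs: Sequence[Sequence[float]],
        conditions: Sequence[Sequence[float]],
    ) -> float:
        # Default training objective: mean negative log-likelihood over a batch.
        if len(xs) == 0 or len(xs) != len(conditions):
            raise ValueError("need a non-empty batch with one condition per input")
        return -sum(self.log_prob(x, c) for x, c in zip(xs, conditions)) / len(xs)

    @abstractmethod
    def _log_prob(self, x: Sequence[float], condition: Sequence[float]) -> float:
        """Subclasses implement the actual density evaluation."""


class LinearGaussianEstimator(ConditionalEstimator):
    """Toy subclass: each entry of x ~ N(weight * sum(condition), sigma**2)."""

    def __init__(self, input_dim: int, condition_dim: int,
                 weight: float = 1.0, sigma: float = 1.0) -> None:
        super().__init__(input_dim, condition_dim)
        self.weight = weight
        self.sigma = sigma

    def _log_prob(self, x: Sequence[float], condition: Sequence[float]) -> float:
        mean = self.weight * sum(condition)
        log_norm = math.log(self.sigma * math.sqrt(2 * math.pi))
        return sum(-0.5 * ((xi - mean) / self.sigma) ** 2 - log_norm for xi in x)
```

For example, `LinearGaussianEstimator(input_dim=1, condition_dim=2).log_prob([1.0], [0.5, 0.5])` evaluates the density at its mean and returns `-0.5 * log(2 * pi)`. Trying to instantiate `ConditionalEstimator` directly raises a `TypeError`, because `_log_prob` is abstract.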

Does this close any currently open issues?

Will close #966. Will replace #1072.

Considerations:

codecov[bot] commented 5 months ago

Codecov Report

Attention: Patch coverage is 67.85714% with 18 lines in your changes missing coverage. Please review.

Project coverage is 72.90%. Comparing base (1b268b8) to head (483ad93). Report is 5 commits behind head on main.

Additional details and impacted files

```diff
@@             Coverage Diff             @@
##             main    #1151       +/-   ##
===========================================
- Coverage   83.08%   72.90%   -10.18%
===========================================
  Files          92       93        +1
  Lines        7259     7397      +138
===========================================
- Hits         6031     5393      -638
- Misses       1228     2004      +776
```

| [Flag](https://app.codecov.io/gh/sbi-dev/sbi/pull/1151/flags?src=pr&el=flags&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev) | Coverage Δ | |
|---|---|---|
| [unittests](https://app.codecov.io/gh/sbi-dev/sbi/pull/1151/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev) | `72.90% <67.85%> (-10.18%)` | :arrow_down: |

Flags with carried forward coverage won't be shown. [Click here](https://docs.codecov.io/docs/carryforward-flags?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#carryforward-flags-in-the-pull-request-comment) to find out more.
| [Files](https://app.codecov.io/gh/sbi-dev/sbi/pull/1151?dropdown=coverage&src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev) | Coverage Δ | |
|---|---|---|
| [sbi/inference/posteriors/direct\_posterior.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1151?src=pr&el=tree&filepath=sbi%2Finference%2Fposteriors%2Fdirect_posterior.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL2luZmVyZW5jZS9wb3N0ZXJpb3JzL2RpcmVjdF9wb3N0ZXJpb3IucHk=) | `98.36% <100.00%> (ø)` | |
| [...inference/potentials/likelihood\_based\_potential.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1151?src=pr&el=tree&filepath=sbi%2Finference%2Fpotentials%2Flikelihood_based_potential.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL2luZmVyZW5jZS9wb3RlbnRpYWxzL2xpa2VsaWhvb2RfYmFzZWRfcG90ZW50aWFsLnB5) | `100.00% <100.00%> (ø)` | |
| [.../inference/potentials/posterior\_based\_potential.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1151?src=pr&el=tree&filepath=sbi%2Finference%2Fpotentials%2Fposterior_based_potential.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL2luZmVyZW5jZS9wb3RlbnRpYWxzL3Bvc3Rlcmlvcl9iYXNlZF9wb3RlbnRpYWwucHk=) | `97.05% <100.00%> (ø)` | |
| [sbi/inference/snle/snle\_base.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1151?src=pr&el=tree&filepath=sbi%2Finference%2Fsnle%2Fsnle_base.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL2luZmVyZW5jZS9zbmxlL3NubGVfYmFzZS5weQ==) | `93.61% <100.00%> (ø)` | |
| [sbi/inference/snpe/snpe\_base.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1151?src=pr&el=tree&filepath=sbi%2Finference%2Fsnpe%2Fsnpe_base.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL2luZmVyZW5jZS9zbnBlL3NucGVfYmFzZS5weQ==) | `89.02% <100.00%> (ø)` | |
| [sbi/neural\_nets/\_\_init\_\_.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1151?src=pr&el=tree&filepath=sbi%2Fneural_nets%2F__init__.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL25ldXJhbF9uZXRzL19faW5pdF9fLnB5) | `100.00% <100.00%> (ø)` | |
| [sbi/neural\_nets/density\_estimators/\_\_init\_\_.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1151?src=pr&el=tree&filepath=sbi%2Fneural_nets%2Fdensity_estimators%2F__init__.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL25ldXJhbF9uZXRzL2RlbnNpdHlfZXN0aW1hdG9ycy9fX2luaXRfXy5weQ==) | `100.00% <100.00%> (ø)` | |
| [.../neural\_nets/density\_estimators/categorical\_net.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1151?src=pr&el=tree&filepath=sbi%2Fneural_nets%2Fdensity_estimators%2Fcategorical_net.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL25ldXJhbF9uZXRzL2RlbnNpdHlfZXN0aW1hdG9ycy9jYXRlZ29yaWNhbF9uZXQucHk=) | `98.03% <100.00%> (ø)` | |
| [...nets/density\_estimators/mixed\_density\_estimator.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1151?src=pr&el=tree&filepath=sbi%2Fneural_nets%2Fdensity_estimators%2Fmixed_density_estimator.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL25ldXJhbF9uZXRzL2RlbnNpdHlfZXN0aW1hdG9ycy9taXhlZF9kZW5zaXR5X2VzdGltYXRvci5weQ==) | `69.11% <100.00%> (ø)` | |
| [sbi/neural\_nets/density\_estimators/nflows\_flow.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1151?src=pr&el=tree&filepath=sbi%2Fneural_nets%2Fdensity_estimators%2Fnflows_flow.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL25ldXJhbF9uZXRzL2RlbnNpdHlfZXN0aW1hdG9ycy9uZmxvd3NfZmxvdy5weQ==) | `63.46% <100.00%> (ø)` | |
| ... and [3 more](https://app.codecov.io/gh/sbi-dev/sbi/pull/1151?src=pr&el=tree-more&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev) | | |

... and [25 files with indirect coverage changes](https://app.codecov.io/gh/sbi-dev/sbi/pull/1151/indirect-changes?src=pr&el=tree-more&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev)

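The headline percentages in this report can be checked against the Coverage Diff totals (hits and tracked lines). The sketch below assumes Codecov's default of truncating percentages to two decimals (`round: down` with `precision: 2`). The helper name `coverage_bp` is hypothetical. The sketch also recovers the size of the measured patch from the 18 missed lines and the 67.85714% patch coverage.

```python
def coverage_bp(hits: int, lines: int) -> int:
    """Line coverage in basis points (hundredths of a percent), truncated
    rather than rounded, matching the assumed Codecov default `round: down`."""
    return (10_000 * hits) // lines


# Totals from the Coverage Diff in the report.
base = coverage_bp(hits=6031, lines=7259)  # main
head = coverage_bp(hits=5393, lines=7397)  # PR #1151
delta = head - base

# Patch coverage: 18 missed lines at 67.85714% pins down the patch size.
missed = 18
patch_fraction = 0.6785714
patch_lines = round(missed / (1 - patch_fraction))
covered = patch_lines - missed

print(f"{base / 100:.2f}% -> {head / 100:.2f}% ({delta / 100:+.2f}%)")
print(f"patch: {covered}/{patch_lines} lines covered")
```

Running it prints `83.08% -> 72.90% (-10.18%)` and `patch: 38/56 lines covered`. This matches the project-level and patch-level figures Codecov reports.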
manuelgloeckler commented 3 months ago

Great, thanks for the review and suggested changes :)

I think it would also be more appropriate to adapt the naming to:

What do you think?