sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

DensityEstimator.loss does not take `sample_dim` #1149

Closed: michaeldeistler closed this 2 months ago

michaeldeistler commented 2 months ago

In #1066, we defined that `log_prob` and `loss` take the same input shapes and return the same output shapes:

`density_estimator.log_prob(input, condition)`

- `input`: `(sample_input, batch_input, *event_shape_input)`
- `condition`: `(batch_condition, *event_shape_condition)`
- returns: `(sample_input, batch_input)`
- raises: if `batch_input != batch_condition`
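The `log_prob` shape contract can be made concrete with a minimal PyTorch sketch. `ToyConditionalGaussian` is a hypothetical class, not sbi's density estimator. It models `input` given `condition` as a unit-variance Gaussian whose mean is a fixed random linear map of the condition. That is enough to show how the leading sample dimension broadcasts against the batch dimension and how a batch mismatch is rejected.

```python
import torch


class ToyConditionalGaussian:
    """Hypothetical stand-in for a conditional density estimator (not sbi's API).

    Models p(input | condition) as a unit-variance Gaussian whose mean is a
    fixed random linear map of the condition.
    """

    def __init__(self, event_dim: int, condition_dim: int):
        generator = torch.Generator().manual_seed(0)
        self.weight = torch.randn(condition_dim, event_dim, generator=generator)

    def log_prob(self, input: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        # input:     (sample_input, batch_input, event_dim)
        # condition: (batch_condition, condition_dim)
        batch_input, batch_condition = input.shape[1], condition.shape[0]
        if batch_input != batch_condition:
            raise ValueError(
                f"batch_input ({batch_input}) != batch_condition ({batch_condition})"
            )
        mean = condition @ self.weight  # (batch_condition, event_dim)
        # batch_shape=(batch,), event_shape=(event_dim,)
        dist = torch.distributions.Independent(
            torch.distributions.Normal(mean, torch.ones_like(mean)), 1
        )
        # The leading sample dim broadcasts over the batch; Independent sums
        # over the event dim, giving one log-prob per (sample, batch) entry.
        return dist.log_prob(input)  # (sample_input, batch_input)


estimator = ToyConditionalGaussian(event_dim=3, condition_dim=2)
theta = torch.randn(10, 4, 3)  # (sample_input=10, batch_input=4, event=3)
x = torch.randn(4, 2)          # (batch_condition=4, condition event=2)
print(estimator.log_prob(theta, x).shape)  # torch.Size([10, 4])
```

Using `torch.distributions.Independent` keeps the event dimension out of the batch shape, so the sample dimension can stay in front without any reshaping.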

However, this PR removes the `sample_dim` from `.loss`. The `.loss` method therefore now has the following signature:

`density_estimator.loss(input, condition)`

- `input`: `(batch_input, *event_shape_input)`
- `condition`: `(batch_condition, *event_shape_condition)`
- returns: `(batch_input,)`
- raises: if `batch_input != batch_condition`
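A minimal PyTorch sketch of the new `.loss` contract next to `log_prob`. `ToyConditionalGaussian` is a hypothetical class, not sbi's implementation, and the sketch assumes that the loss is the negative log-probability (the usual maximum-likelihood objective for flows). `loss` receives a batch with no leading sample dimension and returns one value per batch element. Internally it adds and removes a singleton sample dimension so that it can reuse `log_prob`.

```python
import torch


class ToyConditionalGaussian:
    """Hypothetical toy estimator (not sbi's class) contrasting the shape
    contracts of `log_prob` and `loss`."""

    def __init__(self, event_dim: int, condition_dim: int):
        generator = torch.Generator().manual_seed(0)
        self.weight = torch.randn(condition_dim, event_dim, generator=generator)

    @staticmethod
    def _check_batch(batch_input: int, batch_condition: int) -> None:
        if batch_input != batch_condition:
            raise ValueError(
                f"batch_input ({batch_input}) != batch_condition ({batch_condition})"
            )

    def log_prob(self, input: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        # (sample_input, batch_input, event_dim) -> (sample_input, batch_input)
        self._check_batch(input.shape[1], condition.shape[0])
        mean = condition @ self.weight
        dist = torch.distributions.Independent(
            torch.distributions.Normal(mean, torch.ones_like(mean)), 1
        )
        return dist.log_prob(input)

    def loss(self, input: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        # (batch_input, event_dim) -> (batch_input,); no sample dim.
        # Assumption: loss = negative log-probability (maximum likelihood).
        self._check_batch(input.shape[0], condition.shape[0])
        return -self.log_prob(input.unsqueeze(0), condition).squeeze(0)


estimator = ToyConditionalGaussian(event_dim=3, condition_dim=2)
theta = torch.randn(4, 3)  # (batch_input=4, event=3), no sample dim
x = torch.randn(4, 2)      # (batch_condition=4, condition event=2)
print(estimator.loss(theta, x).shape)  # torch.Size([4])
```

A training loop would typically reduce this per-example loss with `.mean()` to obtain a scalar objective.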


codecov[bot] commented 2 months ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 77.01%. Comparing base (005aeac) to head (6ae9ede).

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main    #1149      +/-   ##
==========================================
- Coverage   85.09%   77.01%   -8.09%
==========================================
  Files          90       90
  Lines        6649     6643       -6
==========================================
- Hits         5658     5116     -542
- Misses        991     1527     +536
```

| [Flag](https://app.codecov.io/gh/sbi-dev/sbi/pull/1149/flags?src=pr&el=flags&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev) | Coverage Δ | |
|---|---|---|
| [unittests](https://app.codecov.io/gh/sbi-dev/sbi/pull/1149/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev) | `77.01% <100.00%> (-8.09%)` | :arrow_down: |

Flags with carried forward coverage won't be shown. [Click here](https://docs.codecov.io/docs/carryforward-flags?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#carryforward-flags-in-the-pull-request-comment) to find out more.

| [Files](https://app.codecov.io/gh/sbi-dev/sbi/pull/1149?dropdown=coverage&src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev) | Coverage Δ | |
|---|---|---|
| [sbi/inference/snle/mnle.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1149?src=pr&el=tree&filepath=sbi%2Finference%2Fsnle%2Fmnle.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL2luZmVyZW5jZS9zbmxlL21ubGUucHk=) | `85.00% <ø> (-8.48%)` | :arrow_down: |
| [sbi/inference/snle/snle\_base.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1149?src=pr&el=tree&filepath=sbi%2Finference%2Fsnle%2Fsnle_base.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL2luZmVyZW5jZS9zbmxlL3NubGVfYmFzZS5weQ==) | `93.61% <100.00%> (ø)` | |
| [sbi/inference/snpe/snpe\_base.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1149?src=pr&el=tree&filepath=sbi%2Finference%2Fsnpe%2Fsnpe_base.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL2luZmVyZW5jZS9zbnBlL3NucGVfYmFzZS5weQ==) | `89.02% <100.00%> (ø)` | |
| [sbi/neural\_nets/density\_estimators/base.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1149?src=pr&el=tree&filepath=sbi%2Fneural_nets%2Fdensity_estimators%2Fbase.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL25ldXJhbF9uZXRzL2RlbnNpdHlfZXN0aW1hdG9ycy9iYXNlLnB5) | `57.14% <ø> (ø)` | |
| [.../neural\_nets/density\_estimators/categorical\_net.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1149?src=pr&el=tree&filepath=sbi%2Fneural_nets%2Fdensity_estimators%2Fcategorical_net.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL25ldXJhbF9uZXRzL2RlbnNpdHlfZXN0aW1hdG9ycy9jYXRlZ29yaWNhbF9uZXQucHk=) | `98.03% <100.00%> (ø)` | |
| [...nets/density\_estimators/mixed\_density\_estimator.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1149?src=pr&el=tree&filepath=sbi%2Fneural_nets%2Fdensity_estimators%2Fmixed_density_estimator.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL25ldXJhbF9uZXRzL2RlbnNpdHlfZXN0aW1hdG9ycy9taXhlZF9kZW5zaXR5X2VzdGltYXRvci5weQ==) | `69.11% <100.00%> (ø)` | |
| [sbi/neural\_nets/density\_estimators/nflows\_flow.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1149?src=pr&el=tree&filepath=sbi%2Fneural_nets%2Fdensity_estimators%2Fnflows_flow.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL25ldXJhbF9uZXRzL2RlbnNpdHlfZXN0aW1hdG9ycy9uZmxvd3NfZmxvdy5weQ==) | `62.74% <100.00%> (ø)` | |
| [sbi/neural\_nets/density\_estimators/zuko\_flow.py](https://app.codecov.io/gh/sbi-dev/sbi/pull/1149?src=pr&el=tree&filepath=sbi%2Fneural_nets%2Fdensity_estimators%2Fzuko_flow.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev#diff-c2JpL25ldXJhbF9uZXRzL2RlbnNpdHlfZXN0aW1hdG9ycy96dWtvX2Zsb3cucHk=) | `64.44% <100.00%> (ø)` | |

... and [22 files with indirect coverage changes](https://app.codecov.io/gh/sbi-dev/sbi/pull/1149/indirect-changes?src=pr&el=tree-more&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=sbi-dev)