sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0
546 stars, 133 forks

feat: batched sampling for MCMC #1176

Open manuelgloeckler opened 2 weeks ago

manuelgloeckler commented 2 weeks ago

What does this implement/fix? Explain your changes

This pull request aims to implement the sample_batched method for MCMC.

Current problem

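The current implementation returns samples of the correct shape, BUT the samples are wrong. This is because the potential function broadcasts and repeats theta against the whole batch of x and finally sums over the first (batch) dimension. In other words, it treats the batch of x as iid observations instead of evaluating each x separately.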
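To make the problem concrete, here is a minimal sketch with a toy Gaussian likelihood N(x; theta, 1) and a flat prior (both are assumptions for illustration, and the function names are hypothetical, not sbi's API). Both potentials return one value per theta, so the shapes agree, but the iid version sums over the whole batch of x and therefore scores every theta against every observation.

```python
import torch
from torch.distributions import Normal


def iid_potential(theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Toy log-likelihood that treats all rows of x as iid draws given each theta.

    theta: (n_theta, dim), x: (n_x, dim) -> (n_theta,)
    """
    # Broadcast (n_theta, 1, dim) against (1, n_x, dim) -> (n_theta, n_x, dim).
    log_probs = Normal(theta.unsqueeze(1), 1.0).log_prob(x.unsqueeze(0)).sum(-1)
    # Summing over the x-batch mixes all observations into every potential.
    return log_probs.sum(dim=1)


def batched_potential(theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Toy log-likelihood that pairs theta[i] with x[i] (one posterior per x).

    theta: (n, dim), x: (n, dim) -> (n,)
    """
    return Normal(theta, 1.0).log_prob(x).sum(-1)


theta = torch.tensor([[0.0], [3.0]])
x = torch.tensor([[0.0], [3.0]])  # two different observations

print(iid_potential(theta, x))      # shape (2,), but each entry mixes x[0] and x[1]
print(batched_potential(theta, x))  # each theta is scored against its own x
```

With theta = [0, 3] and x = [0, 3], the batched potential gives both thetas a high value (each matches its own x), whereas the iid potential assigns both the same, lower value, because each theta is scored against both observations.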

codecov[bot] commented 2 weeks ago

Codecov Report

Attention: Patch coverage is 85.26316% with 14 lines in your changes missing coverage. Please review.

Project coverage is 73.03%. Comparing base (337f072) to head (c55e6e4).

:exclamation: Current head c55e6e4 differs from pull request most recent head 69f459e

Please upload reports for the commit 69f459e to get more accurate results.

:exclamation: There is a different number of reports uploaded between BASE (337f072) and HEAD (c55e6e4).

HEAD has 1 upload more than BASE.

| Flag | BASE (337f072) | HEAD (c55e6e4) |
|------|------|------|
| unittests | 1 | 2 |
Additional details and impacted files

```diff
@@             Coverage Diff             @@
##             main    #1176       +/-   ##
===========================================
- Coverage   84.53%   73.03%   -11.50%
===========================================
  Files          94       93        -1
  Lines        7571     7477       -94
===========================================
- Hits         6400     5461      -939
- Misses       1171     2016      +845
```

| Flag | Coverage Δ | |
|---|---|---|
| unittests | `73.03% <85.26%> (-11.50%)` | :arrow_down: |

Flags with carried forward coverage won't be shown.
| Files | Coverage Δ | |
|---|---|---|
| `sbi/inference/posteriors/direct_posterior.py` | `98.78% <100.00%> (-0.02%)` | :arrow_down: |
| `sbi/inference/posteriors/mcmc_posterior.py` | `86.58% <100.00%> (+0.38%)` | :arrow_up: |
| `sbi/neural_nets/density_estimators/nflows_flow.py` | `63.46% <ø> (ø)` | |
| `sbi/inference/posteriors/base_posterior.py` | `83.72% <66.66%> (-2.33%)` | :arrow_down: |
| `sbi/inference/posteriors/importance_posterior.py` | `55.38% <50.00%> (-30.77%)` | :arrow_down: |
| `sbi/inference/posteriors/rejection_posterior.py` | `80.95% <50.00%> (-11.74%)` | :arrow_down: |
| `sbi/inference/posteriors/vi_posterior.py` | `80.86% <50.00%> (-9.88%)` | :arrow_down: |
| `sbi/samplers/rejection/rejection.py` | `88.00% <92.85%> (ø)` | |
| `sbi/inference/posteriors/ensemble_posterior.py` | `50.00% <11.11%> (-37.97%)` | :arrow_down: |

... and 33 files with indirect coverage changes

gmoss13 commented 1 week ago

I've made some progress on this PR and would like some feedback before I continue.

A BasePotential either allows iid x ("allow_iid") or not.

Given batch_dim_theta != batch_dim_x, we need to decide how to evaluate potential(x, theta). We could return (batch_dim_x, batch_dim_theta) potentials (i.e. every combination), but I am worried this can add a lot of computational overhead, especially when sampling. Instead, in the current implementation I suggest that we assume batch_dim_theta is a multiple of batch_dim_x (i.e. for sampling, we have n chains in theta for each x). In this case, we expand the batch dim of x to batch_dim_theta and match each x to its thetas. If we are happy with this approach, I'll go ahead and also apply it to the MCMC init_strategy etc., and make sure it is consistent with other calls.
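A minimal sketch of the proposed matching, assuming the thetas are laid out as (n_chains, batch_dim_x, dim) and then flattened, so that theta[i] belongs to x[i % batch_dim_x]. The layout, the toy Gaussian likelihood and all function names are illustrative assumptions, not the PR's code. It contrasts the matched evaluation (batch_dim_theta potentials) with the all-combinations alternative (batch_dim_x * batch_dim_theta potentials).

```python
import torch
from torch.distributions import Normal


def expand_x_to_theta(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Repeat x so that theta[i] is paired with x[i % batch_dim_x].

    Assumes theta has shape (n_chains * batch_dim_x, dim), obtained by
    flattening (n_chains, batch_dim_x, dim); this layout is an assumption.
    """
    batch_dim_x, batch_dim_theta = x.shape[0], theta.shape[0]
    if batch_dim_theta % batch_dim_x != 0:
        raise ValueError("batch_dim_theta must be a multiple of batch_dim_x")
    return x.repeat(batch_dim_theta // batch_dim_x, 1)


def matched_potential(theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Toy log-likelihood N(x; theta, 1), one value per theta -> (batch_dim_theta,)."""
    return Normal(theta, 1.0).log_prob(expand_x_to_theta(x, theta)).sum(-1)


def all_combinations_potential(theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Every (x, theta) pair -> (batch_dim_x, batch_dim_theta); cost scales with both."""
    return Normal(theta.unsqueeze(0), 1.0).log_prob(x.unsqueeze(1)).sum(-1)


x = torch.tensor([[0.0], [10.0]])            # batch_dim_x = 2
theta = torch.randn(3, 2, 1).reshape(-1, 1)  # 3 chains per x -> batch_dim_theta = 6
print(matched_potential(theta, x).shape)           # torch.Size([6])
print(all_combinations_potential(theta, x).shape)  # torch.Size([2, 6])
```

The matched result is exactly the subset of the full grid where x index = theta index % batch_dim_x, so it computes only what each chain needs.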

Remove warning for batched x and default to batched evaluation

Not sure if we want batched evaluation as the default. I think it's easier to do batched evaluation only when sample_batched or log_prob_batched is called, and otherwise assume iid (and warn if the batch dim is > 1, as before).
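The suggested default can be sketched with a hypothetical toy potential whose set_x takes an x_is_iid flag. The class, the flag, the toy Gaussian likelihood and the two wrapper functions are assumptions for illustration, not sbi's BasePotential API. The regular path assumes iid x and warns when the batch dim is > 1, while the batched path switches to matched, per-x evaluation.

```python
import warnings

import torch
from torch.distributions import Normal


class ToyPotential:
    """Hypothetical stand-in for a potential function; not sbi's BasePotential."""

    def set_x(self, x: torch.Tensor, x_is_iid: bool = True) -> None:
        # The iid vs. batched choice is made once, when x is set.
        if x_is_iid and x.shape[0] > 1:
            warnings.warn("x has batch dim > 1; treating its rows as iid observations.")
        self.x, self.x_is_iid = x, x_is_iid

    def __call__(self, theta: torch.Tensor) -> torch.Tensor:
        if self.x_is_iid:
            # Toy N(x; theta, 1) summed over all observations -> (batch_dim_theta,)
            log_probs = Normal(theta.unsqueeze(1), 1.0).log_prob(self.x.unsqueeze(0))
            return log_probs.sum(dim=(1, 2))
        # Batched: pair theta[i] with x[i % batch_dim_x]
        # (assumes batch_dim_theta is a multiple of batch_dim_x).
        x = self.x.repeat(theta.shape[0] // self.x.shape[0], 1)
        return Normal(theta, 1.0).log_prob(x).sum(-1)


def toy_log_prob(potential: ToyPotential, theta: torch.Tensor, x: torch.Tensor):
    # Default path: assume iid x (warns if the batch dim is > 1).
    potential.set_x(x, x_is_iid=True)
    return potential(theta)


def toy_log_prob_batched(potential: ToyPotential, theta: torch.Tensor, x: torch.Tensor):
    # Explicitly batched path, as it would sit behind sample_batched / log_prob_batched.
    potential.set_x(x, x_is_iid=False)
    return potential(theta)
```

Keeping iid as the default leaves existing single-x calls unchanged, and batched evaluation only happens when the caller asks for it.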

manuelgloeckler commented 1 week ago

Great, it looks good. I like that the choice of iid or not can now be made in the set_x method, which makes a lot of sense.

I would also opt for your suggested option. The question arises because we squeeze the batch shape into a single dimension, right? With PyTorch broadcasting, one would expect something like (1, batch_x_dim, x_dim) and (batch_theta_dim, batch_x_dim, theta_dim) -> (batch_theta_dim, batch_x_dim). So by squeezing xs and thetas into 2D, the theta batch dimension would always be a multiple of batch_x_dim (otherwise the result could not be represented by a fixed-size tensor).
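The shape argument can be checked directly in PyTorch. In this minimal sketch (sizes chosen arbitrarily), broadcasting (1, batch_x_dim) against (batch_theta_dim, batch_x_dim) gives a (batch_theta_dim, batch_x_dim) batch shape, and flattening theta to 2D yields batch_theta_dim * batch_x_dim rows, a multiple of batch_x_dim, where row i belongs to x[i % batch_x_dim].

```python
import torch

batch_theta_dim, batch_x_dim, theta_dim, x_dim = 3, 2, 4, 5
x = torch.randn(1, batch_x_dim, x_dim)
theta = torch.randn(batch_theta_dim, batch_x_dim, theta_dim)

# Broadcast the batch shapes (1, batch_x_dim) and (batch_theta_dim, batch_x_dim).
batch_shape = torch.broadcast_shapes(x.shape[:-1], theta.shape[:-1])
print(batch_shape)  # torch.Size([3, 2])

# Squeeze (flatten) the batch shape into a single dimension.
theta_2d = theta.reshape(-1, theta_dim)  # (batch_theta_dim * batch_x_dim, theta_dim)
x_2d = x.reshape(-1, x_dim)              # (batch_x_dim, x_dim)
assert theta_2d.shape[0] % x_2d.shape[0] == 0

# Row i of theta_2d is theta[i // batch_x_dim, i % batch_x_dim],
# so it is paired with x_2d[i % batch_x_dim].
x_matched = x_2d.repeat(batch_theta_dim, 1)
print(theta_2d.shape, x_matched.shape)  # torch.Size([6, 4]) torch.Size([6, 5])
```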

For (1, batch_x_dim, x_dim) and (batch_theta_dim, 1, theta_dim), PyTorch broadcasting semantics would compute all combinations, i.e. a (batch_theta_dim, batch_x_dim) result. Unfortunately, after squeezing, these distinctions between cases can no longer be fully preserved.
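A small illustration of the lost distinction (sizes chosen arbitrarily): an all-combinations theta of shape (4, 1, theta_dim) and a matched theta of shape (2, 2, theta_dim), both paired with x of shape (1, 2, x_dim), broadcast to 8 and 4 potentials respectively, yet both flatten to the same 2D shapes.

```python
import torch

theta_dim, x_dim = 3, 5
x = torch.randn(1, 2, x_dim)                  # (1, batch_x_dim, x_dim)
theta_all = torch.randn(4, 1, theta_dim)      # 4 thetas, each paired with every x
theta_matched = torch.randn(2, 2, theta_dim)  # 2 thetas per x

# Broadcasting keeps the two cases apart before flattening ...
print(torch.broadcast_shapes(x.shape[:-1], theta_all.shape[:-1]))      # torch.Size([4, 2]): 8 potentials
print(torch.broadcast_shapes(x.shape[:-1], theta_matched.shape[:-1]))  # torch.Size([2, 2]): 4 potentials

# ... but after squeezing into 2D both cases look identical.
flat_all = theta_all.reshape(-1, theta_dim)
flat_matched = theta_matched.reshape(-1, theta_dim)
print(flat_all.shape == flat_matched.shape)  # True
```

Given only the 2D tensors (4, theta_dim) and (2, x_dim), a potential cannot tell which of the two broadcasts was intended, which is why a fixed convention (such as the matched one) is needed.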