bambinos / bambi

BAyesian Model-Building Interface (Bambi) in Python.
https://bambinos.github.io/bambi/
MIT License

Posterior draws for comparisons(), predictions(), and slopes() functions #703

Closed. zwelitunyiswa closed this issue 11 months ago.

zwelitunyiswa commented 1 year ago

@tomicapretto @GStechschulte Thanks for your hard work on this module of Bambi. I was playing around with it this past week and I am mighty impressed! I love that I can create a data frame with summaries that I can plot in the Bambi plotting module or any plotting library I prefer. That is huge for reporting.

I was wondering if it would be possible to add a function that returns the underlying posterior draws behind the summaries produced by comparisons(), predictions(), and slopes(), like you can do in marginaleffects with the posterior_draws() function?

That would be super useful for reporting probabilities, ROPE, etc.
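To illustrate why the raw draws matter for this kind of reporting: once you have the posterior draws of a contrast, the probability that it is positive and the share of draws inside a region of practical equivalence (ROPE) are simple proportions. The sketch below assumes the draws for one row of a summary are available as a plain list of floats; `summarize_contrast` is a hypothetical helper (not a Bambi or marginaleffects function), and the ROPE bounds of ±0.1 are an arbitrary placeholder.

```python
import random
import statistics


def summarize_contrast(draws, rope=(-0.1, 0.1)):
    """Summarize posterior draws of a contrast (hypothetical helper).

    Returns the posterior mean, the probability that the contrast is positive,
    and the proportion of draws inside the ROPE [low, high].
    """
    n = len(draws)
    low, high = rope
    return {
        "mean": statistics.fmean(draws),
        "p_positive": sum(d > 0 for d in draws) / n,
        "p_in_rope": sum(low <= d <= high for d in draws) / n,
    }


# Simulated stand-in for the draws behind one row of a comparisons() summary.
rng = random.Random(1234)
draws = [rng.gauss(0.5, 0.3) for _ in range(4000)]
print(summarize_contrast(draws))
```

The ROPE width is a modelling decision that depends on the scale of the outcome, so it is left as a parameter rather than fixed.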

tomicapretto commented 1 year ago

Hi @zwelitunyiswa thanks for the nice words! All credit belongs to Gabriel :)

@GStechschulte do you think that would be possible?

GStechschulte commented 1 year ago

Hey @zwelitunyiswa, thanks a lot for the kind words and feedback. Yeah, that is possible. I admit I have yet to use the posterior_draws() function in marginaleffects. I will try it and see how it returns the posterior draws.

My initial thought is to either (1) add a function that returns an az.InferenceData object or (2) add an argument such as return_idata=True to comparisons(), predictions(), and slopes() that returns an az.InferenceData object.

I like the idea of returning an InferenceData object because Bambi and PyMC users are likely familiar with that data structure, and because it handles the n-dimensional nature of the posterior draws well.
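To make the "n-dimensional" point concrete: each quantity computed from the posterior carries at least a chain dimension, a draw dimension, and an observation dimension (one per row of the data grid). ArviZ's InferenceData keeps chain and draw as labeled xarray dimensions; the dependency-free sketch below uses nested lists indexed `[chain][draw][obs]` instead, with made-up sizes and simulated values, and shows how a per-observation summary comes from reducing over the two sample dimensions. `mean_over_samples` is a hypothetical helper, not part of Bambi.

```python
import random
import statistics

# Hypothetical sizes; the fit in this thread uses 4 chains x 1000 draws.
N_CHAINS, N_DRAWS, N_OBS = 4, 250, 3

rng = random.Random(0)
# draws[c][d][o]: simulated value for chain c, draw d, observation o.
draws = [
    [[rng.gauss(float(o), 1.0) for o in range(N_OBS)] for _ in range(N_DRAWS)]
    for _ in range(N_CHAINS)
]


def mean_over_samples(draws):
    """Average over the (chain, draw) dimensions, keeping one value per observation."""
    n_obs = len(draws[0][0])
    return [
        statistics.fmean(draw[o] for chain in draws for draw in chain)
        for o in range(n_obs)
    ]


print(mean_over_samples(draws))  # one posterior mean per observation
```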

Let me think some more and do some prototyping, and I will communicate back here. Thanks!

GStechschulte commented 1 year ago

Okay @zwelitunyiswa, I have used the marginaleffects posterior_draws() function, and I believe the simplest approach is an optional return_posterior argument in comparisons(), predictions(), and slopes() that merges each posterior draw with the observation that "produced" it and returns the result as a data frame. Below you can find an example:

import bambi as bmb
import pandas as pd

fish_data = pd.read_stata("http://www.stata-press.com/data/r11/fish.dta")
cols = ["count", "livebait", "camper", "persons", "child"]
fish_data = fish_data[cols]
fish_data["livebait"] = pd.Categorical(fish_data["livebait"])
fish_data["camper"] = pd.Categorical(fish_data["camper"])

fish_model = bmb.Model(
    "count ~ livebait + camper + persons + child", 
    fish_data, 
    family='zero_inflated_poisson'
)

fish_idata = fish_model.fit(
    draws=1000, 
    target_accept=0.95, 
    random_seed=1234, 
    chains=4
)

With return_posterior=True, two data frames are returned: first, the summary data frame (unchanged), and second, the posterior draws flattened and merged with the data used to generate the predictions.

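from bambi.plots import comparisons  # exposed as bambi.interpret.comparisons in newer Bambi releases

summary_df, post_draws_df = comparisons(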
    model=fish_model,
    idata=fish_idata,
    contrast={"persons": [1, 4]},
    conditional={"child": [0, 1, 2], "livebait": [0, 1]},
    return_posterior=True
)

post_draws_df
| chain | draw | livebait_dim | camper_dim | Intercept | livebait | camper | persons | child | count_psi | count_mean | child_obs | livebait_obs | persons_obs | camper_obs |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 1.0 | 1.0 | -2.338197 | 1.592060 | 0.653525 | 0.843256 | -1.284886 | 0.595906 | 0.431100 | 0.0 | 0.0 | 1.0 | 1.0 |
| 0 | 1 | 1.0 | 1.0 | -2.305775 | 1.538601 | 0.690644 | 0.855629 | -1.326705 | 0.570563 | 0.467899 | 0.0 | 0.0 | 1.0 | 1.0 |
| 0 | 2 | 1.0 | 1.0 | -2.078147 | 1.368352 | 0.742767 | 0.821930 | -1.226769 | 0.623266 | 0.598427 | 0.0 | 0.0 | 1.0 | 1.0 |
| 0 | 3 | 1.0 | 1.0 | -2.790560 | 2.066620 | 0.542978 | 0.875392 | -1.345364 | 0.541648 | 0.253551 | 0.0 | 0.0 | 1.0 | 1.0 |
| 0 | 4 | 1.0 | 1.0 | -2.557727 | 1.797058 | 0.677152 | 0.856107 | -1.414570 | 0.632578 | 0.358988 | 0.0 | 0.0 | 1.0 | 1.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 3 | 995 | 1.0 | 1.0 | -2.796267 | 1.891494 | 0.719850 | 0.900202 | -1.393952 | 0.602064 | 1.873827 | 2.0 | 1.0 | 4.0 | 1.0 |
| 3 | 996 | 1.0 | 1.0 | -2.675291 | 1.730114 | 0.732860 | 0.886734 | -1.188741 | 0.583542 | 2.604234 | 2.0 | 1.0 | 4.0 | 1.0 |
| 3 | 997 | 1.0 | 1.0 | -2.389797 | 1.695540 | 0.602298 | 0.857744 | -1.504039 | 0.633454 | 1.392279 | 2.0 | 1.0 | 4.0 | 1.0 |
| 3 | 998 | 1.0 | 1.0 | -2.556063 | 1.641193 | 0.721424 | 0.882217 | -1.259443 | 0.579240 | 2.262654 | 2.0 | 1.0 | 4.0 | 1.0 |
| 3 | 999 | 1.0 | 1.0 | -2.663126 | 1.902570 | 0.519382 | 0.900782 | -1.361077 | 0.614081 | 1.896101 | 2.0 | 1.0 | 4.0 | 1.0 |

where the columns are:

What do you think? (feel free to comment @tomicapretto @aloctavodia)
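The "flatten and merge" step proposed in this comment can be sketched in plain Python: every (chain, draw, observation) combination becomes one row carrying the chain and draw indices, the drawn value, and the covariates of the observation that produced it, suffixed with `_obs` as in the columns shown above. This is only an illustration under stated assumptions, not Bambi's implementation: draws are assumed to be indexed `[chain][draw][obs]`, `flatten_and_merge` and the `estimate` column name are hypothetical, and the numbers are made up. Unlike the real output, it omits the parameter columns (Intercept, livebait, ...).

```python
from itertools import product


def flatten_and_merge(draws, observations, value_name="estimate"):
    """Turn draws[chain][draw][obs] into one row (dict) per (chain, draw, obs)."""
    rows = []
    n_chains, n_draws = len(draws), len(draws[0])
    for c, d in product(range(n_chains), range(n_draws)):
        for o, obs in enumerate(observations):
            row = {"chain": c, "draw": d, value_name: draws[c][d][o]}
            # Attach the covariate values of the observation that produced this draw.
            row.update({f"{key}_obs": val for key, val in obs.items()})
            rows.append(row)
    return rows


# Hypothetical data grid with two conditions.
observations = [{"child": 0, "livebait": 0}, {"child": 2, "livebait": 1}]
# Made-up values: 1 chain x 2 draws x 2 observations.
draws = [[[0.5, 1.5], [0.6, 2.5]]]

for row in flatten_and_merge(draws, observations):
    print(row)
```

Passing these rows to `pandas.DataFrame` would give a long-format table like the one above, which is convenient for grouping by condition or plotting with any library.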

tomicapretto commented 1 year ago

I think this solution is awesome!

zwelitunyiswa commented 1 year ago

This is spot on. It’s a nice elegant solution. Well done, mate.

(Email reply to Tomás Capretto's comment of Sat, Aug 19, 2023, 07:49.)

GStechschulte commented 1 year ago

Thanks for the quick feedback @tomicapretto and @zwelitunyiswa 👍🏼 I will open a PR in the next few days. Cheers!