bambinos / bambi

BAyesian Model-Building Interface (Bambi) in Python.
https://bambinos.github.io/bambi/
MIT License
1.08k stars, 124 forks

Return interpret data #758

Closed GStechschulte closed 11 months ago

GStechschulte commented 1 year ago

This PR addresses issues https://github.com/bambinos/bambi/issues/703 and #751 by adding a parameter return_idata: bool = False to comparisons(), predictions(), and slopes(). When it is set to True, each posterior draw is merged with the observation that "produced" it, and the result is returned as a DataFrame.

Most of the diff comes from a new test file that covers the non-plotting functionality of the interpret sub-package, which test_plots.py does not test.

import bambi as bmb
import pandas as pd

fish_data = pd.read_stata("http://www.stata-press.com/data/r11/fish.dta")
cols = ["count", "livebait", "camper", "persons", "child"]
fish_data = fish_data[cols]
fish_data["livebait"] = pd.Categorical(fish_data["livebait"])
fish_data["camper"] = pd.Categorical(fish_data["camper"])

fish_model = bmb.Model(
    "count ~ livebait + camper + persons + child", 
    fish_data, 
    family='zero_inflated_poisson'
)

fish_idata = fish_model.fit(
    draws=1000, 
    target_accept=0.95, 
    random_seed=1234, 
    chains=4
)

With return_idata=True, a single DataFrame is returned. It contains the draws from the posterior group of the InferenceData object, the observed data, and the parameter estimates. If the user calls predictions with pps=True, the posterior predictive group is used instead. {marginaleffects} offers similar functionality for Bayesian models.
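As a rough illustration of what "merging draws with observations" means, here is a minimal pandas sketch. It is not the PR's implementation: the `obs` index name, the tiny grid, and the random `count_mean` array are all hypothetical stand-ins for the real posterior.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
n_chains, n_draws = 2, 3

# Hypothetical data grid: one row per observation the estimates refer to.
grid = pd.DataFrame({"persons_obs": [1.0, 4.0], "child_obs": [0.0, 0.0]})
grid.index.name = "obs"  # hypothetical name for the observation index

# Hypothetical posterior draws of the mean, shape (chain, draw, obs).
count_mean = rng.gamma(2.0, 0.5, size=(n_chains, n_draws, len(grid)))

# Flatten the draws to long format: one row per (chain, draw, obs).
# from_product varies the last level fastest, matching C-order ravel().
index = pd.MultiIndex.from_product(
    [range(n_chains), range(n_draws), grid.index],
    names=["chain", "draw", "obs"],
)
draws_long = pd.DataFrame({"count_mean": count_mean.ravel()}, index=index).reset_index()

# Attach the observation that each draw belongs to.
merged = draws_long.merge(grid.reset_index(), on="obs").drop(columns="obs")
print(merged.head())
```

Each draw ends up on its own row next to the covariate values it was computed for, which is why the returned frame grows with chains × draws × observations.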

Below are a few examples:

bmb.interpret.predictions(
    model=fish_model,
    idata=fish_idata,
    conditional=["persons", "child", "livebait"],
    return_idata=True
) 
| chain | draw | livebait_dim | camper_dim | Intercept | livebait | camper | persons | child | count_psi | count_mean | persons_obs | child_obs | livebait_obs | camper_obs |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 1.0 | 1.0 | -2.560515 | 1.877271 | 0.658819 | 0.850313 | -1.256469 | 0.619094 | 0.349454 | 1.0 | 0.0 | 0.0 | 1.0 |
| 0 | 1 | 1.0 | 1.0 | -2.746079 | 1.852347 | 0.794963 | 0.848145 | -1.302583 | 0.643842 | 0.331884 | 1.0 | 0.0 | 0.0 | 1.0 |
| 0 | 2 | 1.0 | 1.0 | -2.669619 | 1.674642 | 0.693835 | 0.923590 | -1.588417 | 0.683065 | 0.349171 | 1.0 | 0.0 | 0.0 | 1.0 |
| 0 | 3 | 1.0 | 1.0 | -2.581749 | 1.474203 | 0.661846 | 0.958753 | -1.624487 | 0.684139 | 0.382453 | 1.0 | 0.0 | 0.0 | 1.0 |
| 0 | 4 | 1.0 | 1.0 | -2.866501 | 2.162873 | 0.575769 | 0.860328 | -1.311473 | 0.647937 | 0.239212 | 1.0 | 0.0 | 0.0 | 1.0 |

1200000 rows × 15 columns

Returning the inference data when calling comparisons lets the user conduct more specific or complex comparisons using group-by aggregations:

bmb.interpret.comparisons(
    model=fish_model,
    idata=fish_idata,
    contrast={"persons": [1, 4]},
    conditional={"child": [0, 1, 2], "livebait": [0, 1]},
    return_idata=True
) 
| chain | draw | livebait_dim | camper_dim | Intercept | livebait | camper | persons | child | count_psi | count_mean | child_obs | livebait_obs | persons_obs | camper_obs |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 1.0 | 1.0 | -2.560515 | 1.877271 | 0.658819 | 0.850313 | -1.256469 | 0.619094 | 0.349454 | 0.0 | 0 | 1.0 | 1.0 |
| 0 | 1 | 1.0 | 1.0 | -2.746079 | 1.852347 | 0.794963 | 0.848145 | -1.302583 | 0.643842 | 0.331884 | 0.0 | 0 | 1.0 | 1.0 |
| 0 | 2 | 1.0 | 1.0 | -2.669619 | 1.674642 | 0.693835 | 0.923590 | -1.588417 | 0.683065 | 0.349171 | 0.0 | 0 | 1.0 | 1.0 |
| 0 | 3 | 1.0 | 1.0 | -2.581749 | 1.474203 | 0.661846 | 0.958753 | -1.624487 | 0.684139 | 0.382453 | 0.0 | 0 | 1.0 | 1.0 |
| 0 | 4 | 1.0 | 1.0 | -2.866501 | 2.162873 | 0.575769 | 0.860328 | -1.311473 | 0.647937 | 0.239212 | 0.0 | 0 | 1.0 | 1.0 |

48000 rows × 15 columns
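The kind of group-by aggregation this enables could look like the sketch below. It runs on a small synthetic frame that only mimics the column names shown above (the values are random, not model output), and it summarises with an equal-tailed quantile interval rather than an HDI. The helper names `q03` and `q97` are hypothetical.

```python
import itertools

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Synthetic stand-in for the DataFrame returned with return_idata=True.
chains, draws = 2, 50
rows = itertools.product(range(chains), range(draws), [0, 1, 2], [0, 1], [1, 4])
df = pd.DataFrame(
    list(rows),
    columns=["chain", "draw", "child_obs", "livebait_obs", "persons_obs"],
)
df["count_mean"] = rng.gamma(shape=2.0, scale=0.5, size=len(df))

# One row per (chain, draw, child, livebait); one column per contrast value.
wide = df.pivot_table(
    index=["chain", "draw", "child_obs", "livebait_obs"],
    columns="persons_obs",
    values="count_mean",
)
# Per-draw difference in the mean between persons=4 and persons=1.
wide["diff"] = wide[4] - wide[1]


def q03(s):
    return s.quantile(0.03)


def q97(s):
    return s.quantile(0.97)


# Summarise the per-draw differences within each child/livebait cell.
summary = (
    wide.groupby(["child_obs", "livebait_obs"])["diff"]
    .agg(mean="mean", lower=q03, upper=q97)
    .reset_index()
)
print(summary)
```

Because the contrast is computed draw by draw before summarising, the interval reflects the posterior uncertainty of the difference itself, not the difference of two separately summarised means.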

Initially, I wanted to return the az.InferenceData object. However, I settled on a DataFrame because of limitations such as the following error, which xarray raised when grouping by a list of coordinates:

TypeError: group must be an xarray.DataArray or the name of an xarray variable or dimension. Received ['coord1', 'coord2'] instead.

Note: depending on the model specification and the number of chains and draws, the returned DataFrame may contain millions of rows.
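To estimate the size in advance, a back-of-the-envelope sketch is shown here. It assumes one row per chain, draw, and row of the data grid; `expected_rows` is a hypothetical helper, not part of Bambi. Under that assumption the comparisons example above (3 child values × 2 livebait values × 2 persons values) gives the 48000 rows reported.

```python
def expected_rows(chains: int, draws: int, grid_rows: int) -> int:
    """Rows in the returned DataFrame, assuming one row per chain, draw, and grid row."""
    return chains * draws * grid_rows


# comparisons example: child in {0, 1, 2}, livebait in {0, 1}, persons in {1, 4}
comparison_grid_rows = 3 * 2 * 2
print(expected_rows(4, 1000, comparison_grid_rows))  # 48000

# predictions example: 1200000 rows over 4 x 1000 draws implies this many grid rows
print(1_200_000 // (4 * 1000))  # 300
```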

To do:

codecov-commenter commented 1 year ago

Codecov Report

Attention: 1 lines in your changes are missing coverage. Please review.

Comparison is base (dcd879b) 89.90% compared to head (6208fe8) 89.91%.

| Files | Patch % | Lines |
|---|---|---|
| bambi/interpret/utils.py | 90.00% | 1 Missing :warning: |

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #758      +/-   ##
==========================================
+ Coverage   89.90%   89.91%   +0.01%
==========================================
  Files          45       45
  Lines        3713     3729      +16
==========================================
+ Hits         3338     3353      +15
- Misses        375      376       +1
```


GStechschulte commented 1 year ago

> Thanks for the nice feature! Just a couple of suggestions.

Thanks for the review. I will incorporate these once we finalize the implementation per our conversation on Slack.

GStechschulte commented 11 months ago

Closing in favor of #762