bambinos / bambi

BAyesian Model-Building Interface (Bambi) in Python.
https://bambinos.github.io/bambi/
MIT License

interpret predictions enhancements #736

Closed: GStechschulte closed this pull request 1 year ago

GStechschulte commented 1 year ago

This PR adds new functionality to the predictions and plot_predictions functions in the interpret submodule and resolves #735. Users can now

Previously, users could only pass a string or a list of covariates to compute conditional adjusted predictions. Now, predictions has "most of" the functionality that {marginaleffects} has. Additionally, the changes result in a more consistent API across comparisons, predictions, and slopes: each function has the argument conditional, on which the user can "condition" their estimates. Furthermore, each function can now compute:

Below are a couple of demos:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import warnings

import bambi as bmb

data = bmb.load_data('mtcars')
data["cyl"] = data["cyl"].replace({4: "low", 6: "medium", 8: "high"})
data["gear"] = data["gear"].replace({3: "A", 4: "B", 5: "C"})
data["cyl"] = pd.Categorical(data["cyl"], categories=["low", "medium", "high"], ordered=True)

model = bmb.Model("mpg ~ 0 + hp * wt + cyl + gear", data)
idata = model.fit(draws=1000, target_accept=0.95, random_seed=1234)

# unit-level predictions
bmb.interpret.predictions(
    model,
    idata,
    conditional=None
)
cyl     gear  hp   wt     estimate   lower_3.0%  upper_97.0%
medium  B     110  2.620  22.233424  20.051966   24.476544
medium  B     110  2.875  21.320402  19.196886   23.344765
low     B     93   2.320  25.901435  24.255648   27.558330
medium  A     110  3.215  18.751708  16.259737   21.293185
high    A     175  3.440  16.908354  15.261489   18.666662
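
Each row summarizes the posterior draws of one prediction. As a rough sketch of what the three summary columns mean, assuming that estimate is the posterior mean and that the bounds come from a 94% highest-density interval (the ArviZ default probability, which would explain the 3.0% / 97.0% labels), a single row could be summarized from its draws as below. The helpers hdi and summarize are hypothetical, and the draws are simulated rather than taken from the fitted model.

```python
import numpy as np


def hdi(draws: np.ndarray, prob: float = 0.94) -> tuple:
    """Narrowest interval containing `prob` of the draws (suited to unimodal posteriors)."""
    sorted_draws = np.sort(draws)
    n = len(sorted_draws)
    window = int(np.floor(prob * n))
    # Width of every candidate interval spanning `window` consecutive sorted draws.
    widths = sorted_draws[window:] - sorted_draws[: n - window]
    start = int(np.argmin(widths))
    return float(sorted_draws[start]), float(sorted_draws[start + window])


def summarize(draws: np.ndarray, prob: float = 0.94) -> dict:
    """Posterior mean plus HDI bounds, labeled like the columns in the table above."""
    lower, upper = hdi(draws, prob)
    tail = round((1 - prob) / 2 * 100, 1)  # 3.0 for prob=0.94
    return {
        "estimate": float(np.mean(draws)),
        f"lower_{tail}%": lower,
        f"upper_{round(100 - tail, 1)}%": upper,
    }


# Simulated draws for a single unit (assumption: not output from the mtcars model).
rng = np.random.default_rng(1234)
print(summarize(rng.normal(loc=22.2, scale=1.2, size=4000)))
```
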

Compute unit-level predictions, then average them over hp and wt to obtain average (marginal) predictions for each combination of gear and cyl:

bmb.interpret.plot_predictions(
    model,
    idata,
    conditional=None,
    average_by=["gear", "cyl"],
    fig_kwargs={"figsize": (7, 3)},
);
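
The averaging step can be sketched outside Bambi. The version below is a minimal pandas illustration, not Bambi's implementation. It assumes a hypothetical long-format table of posterior draws with the columns draw, the grouping covariates, and pred. It averages the unit-level predictions within each group separately for every draw, so each group average keeps a full posterior, and then summarizes with an equal-tailed interval, chosen here only for simplicity. Bambi may aggregate its summary columns differently.

```python
import numpy as np
import pandas as pd


def average_predictions(draws_long: pd.DataFrame, by: list, prob: float = 0.94) -> pd.DataFrame:
    """Hypothetical helper: average unit-level prediction draws within groups, then summarize."""
    # 1. Within each posterior draw, average the predictions of all units in each group.
    per_draw = draws_long.groupby(["draw", *by], as_index=False, observed=True)["pred"].mean()
    # 2. Summarize each group's averaged draws with an equal-tailed interval.
    grouped = per_draw.groupby(by, observed=True)["pred"]
    tail = (1 - prob) / 2
    summary = pd.DataFrame({
        "estimate": grouped.mean(),
        "lower": grouped.quantile(tail),
        "upper": grouped.quantile(1 - tail),
    })
    return summary.reset_index()


# Toy input: 3 units x 200 simulated draws (assumption: not real model output).
rng = np.random.default_rng(1)
units = pd.DataFrame({"unit": [0, 1, 2], "cyl": ["low", "low", "high"], "gear": ["A", "A", "B"]})
draws = pd.DataFrame({"draw": np.repeat(np.arange(200), 3), "unit": np.tile([0, 1, 2], 200)})
draws = draws.merge(units, on="unit")
draws["pred"] = rng.normal(loc=20.0, scale=1.0, size=len(draws))
print(average_predictions(draws, by=["gear", "cyl"]))
```
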


Build a pairwise grid from user-provided values and compute predictions on it:

bmb.interpret.plot_predictions(
    model,
    idata,
    conditional={
        "hp": [100, 120],
        "cyl": ["low", "medium", "high"],
        "gear": "A",
    },
    subplot_kwargs={"main": "hp", "group": "gear", "panel": "cyl"},
    fig_kwargs={"figsize": (10, 4), "sharey": True},
    legend=True
);
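
Because conditional here lists 2 values of hp, 3 levels of cyl, and 1 level of gear, the pairwise grid has 2 x 3 x 1 = 6 rows. A minimal sketch of building such a Cartesian-product grid with itertools and pandas follows. make_grid is a hypothetical helper, not part of Bambi, and it ignores covariates the user does not list (such as wt in this model), which the real function also has to fill in before predicting.

```python
import itertools

import pandas as pd


def make_grid(conditional: dict) -> pd.DataFrame:
    """Hypothetical helper: every combination of the supplied covariate values."""
    # Wrap scalar values such as "gear": "A" in a list so that every entry is iterable.
    values = {
        name: list(v) if isinstance(v, (list, tuple)) else [v]
        for name, v in conditional.items()
    }
    rows = itertools.product(*values.values())
    return pd.DataFrame(list(rows), columns=list(values))


grid = make_grid({"hp": [100, 120], "cyl": ["low", "medium", "high"], "gear": "A"})
print(grid)  # 6 rows: 2 hp values x 3 cyl levels x 1 gear level
```
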


To do:

GStechschulte commented 1 year ago

@tomicapretto, although ~660 lines of code were added 😵‍💫, a lot of that comes from docstring additions, more error handling, and tests. Most of the added functionality reuses existing functions.

tomicapretto commented 1 year ago

@GStechschulte being 100% honest I can't understand all the changes as well as you do. I do see in many cases you made things more general and you are reusing code more, which is great. So once you finish the implementation, I trust your judgment to merge this.

Two more comments:

  1. The test failure is caused by a type hint that is not compatible with Python 3.9. I think it's fixed in https://github.com/pymc-devs/pymc/pull/6945, but I'm not sure.
  2. Do you think we can cut a 0.13.0 release after this? I just realized we don't have a release that includes the interpret submodule, and I think it's quite mature at this point.

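
The comment does not say which hint failed on Python 3.9. One common culprit, shown here only as an illustration, is the PEP 604 union syntax (X | None): it fails with a TypeError when evaluated on Python 3.9, for example in a function signature without from __future__ import annotations, and only works from Python 3.10. The typing.Optional / Union spelling works on 3.9 as well. The helper below is hypothetical.

```python
from typing import Optional, Union


# On Python 3.9, writing this hint as `str | list[str] | None` raises TypeError
# when the signature is evaluated, because `|` between types (PEP 604) arrived in 3.10.
# The typing-module spelling below works on 3.9 and later.
def normalize_covariates(covariates: Optional[Union[str, list[str]]] = None) -> list[str]:
    """Hypothetical helper: accept a single covariate name, a list of names, or None."""
    if covariates is None:
        return []
    if isinstance(covariates, str):
        return [covariates]
    return list(covariates)
```
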
GStechschulte commented 1 year ago

> being 100% honest I can't understand all the changes as well as you do

Mmm. That's not a good sign in my opinion. Is it the code diff that you are unsure about, or what has changed in the predictions functionality?

> The test failure is because of some non-compatible type hint on Python 3.9. I think it's fixed in https://github.com/pymc-devs/pymc/pull/6945 but I'm not sure.

It has been resolved.

> Do you think we can cut a 0.13.0 release after this? I just realized we don't have a release with the interpret submodule and I think it's quite mature at this point.

Yup, I think we can go ahead with that 👍🏼

tomicapretto commented 1 year ago

> Mmm. That's not a good sign in my opinion. Is it the code diff that you are unsure about, or what has changed in the predictions functionality?

Oh no, I don't mean this in a bad way at all. What I'm saying is that the submodule grew a lot, for good reasons, and I'm not as familiar with everything as you are. So I can only provide a high-level review; going deep into the details would take much more time.

tomicapretto commented 1 year ago

Looks great! Go ahead and merge if you don't plan to add anything.

jt-lab commented 1 year ago

@GStechschulte, many thanks for the work on the submodule; it really helps a lot! I can report that the preliminary average_by solutions already worked very well, not only on the simulated dataset but also on other datasets. The predictions align well with the observed data.

GStechschulte commented 1 year ago

@jt-lab Thank you and this is great to hear! 😄 We / I really appreciate you taking the time to open the issues and to give feedback. Cheers!