bambinos / bambi

BAyesian Model-Building Interface (Bambi) in Python.
https://bambinos.github.io/bambi/
MIT License

plot_predictions with random effects #735

Closed by jt-lab 1 year ago

jt-lab commented 1 year ago

@GStechschulte, we have played around further with plot_predictions and came across some behavior we don't understand. It might be an issue with the handling of random effects, so I describe it here:

For a model with random effects (p(correct, count) ~ 0 + factor3:factor2 + factor1 + (0 + factor3:factor2 | individual)), the predictions seem to be off compared to the data points (e.g. see Factor1=A, Factor2=orange, Factor3=1 but also others):

[Figure: predictions from the model with random effects, overlaid on the data points]

Also, for level C of factor1, the HDI bars get a bit smaller. It's barely visible here, but in another data set (which I cannot share), they are about 3 times smaller than those of levels A and B, with no apparent reason in the data.

The same model but without the random effects produces predictions which are pretty close to the empirical means:

[Figure: predictions from the model without random effects, overlaid on the data points]

Many thanks in advance!

Code to reproduce this and the dataset:

import pandas as pd
import bambi as bmb
import seaborn as sns

data = pd.read_csv('simulated_data_order_prob2.csv')

priors = {
    "factor3:factor2": bmb.Prior("Normal", mu=0, sigma=1),
    "factor1": bmb.Prior("Normal", mu=0, sigma=1),
    "factor3:factor2|individual": bmb.Prior("Normal", mu=0, sigma=bmb.Prior("Gamma", alpha=3, beta=3))
}

model = bmb.Model(
    "p(correct, count) ~ 0 + factor3:factor2 + factor1  + (0 + factor3:factor2 | individual)",
    data,
    family="binomial",
    categorical=["factor3", "individual", "factor1", "factor2"],
    priors=priors,
    noncentered=False)

idata = model.fit(tune=2000, draws=2000, random_seed=123, init='adapt_diag',
                  target_accept=0.9, idata_kwargs={'log_likelihood':True})

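# Observed proportion of correct responses per row (a proportion in [0, 1], despite the label)
data['report frequency (%)'] = data['correct'] / data['count']

# Plot the observed proportions, one panel per level of factor1
g = sns.catplot(data=data, kind='strip', x='factor3', y='report frequency (%)',
                hue='factor2', col='factor1', jitter=False, dodge=True)

# Overlay the model predictions on the seaborn axes
axs = bmb.interpret.plot_predictions(
    model=model,
    idata=idata,
    covariates=["factor3", "factor2", "factor1"],
    pps=False,
    legend=True,
    fig_kwargs={"figsize": (20, 8), "sharey": True},
    prob=0.95,
    ax=g.axes
)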

Data Set:

simulated_data_order_prob2.csv

GStechschulte commented 1 year ago

Hey @jt-lab thanks a lot for raising the issue and sharing the code / dataset!

At a quick glance, this happens because the random effects model includes individual as a term. Since this term is not passed to plot_predictions, the predictions are computed at a default value, individual=0 (for a categorical variable, Bambi takes the mode).

bmb.interpret.predictions(
    model=model, 
    idata=idata, 
    covariates=["factor3", "factor2", "factor1"],
)
factor3  factor2  factor1  individual  estimate  lower_3.0%  upper_97.0%
1        X        A        0           0.370441  0.331295    0.407852
2        X        A        0           0.545766  0.504590    0.586051
1        Y        A        0           0.301027  0.268440    0.337534
...      ...      ...      ...         ...       ...         ...
2        X        C        0           0.453661  0.426735    0.480137
1        Y        C        0           0.229191  0.208711    0.249756
2        Y        C        0           0.383008  0.360000    0.407484

Thus, your first plot above compares predictions for individual=0 with the population-level data points. Since the second model does not include a random effect (and thus has no individual term), its predictions match the population-level data points more closely (as you stated).
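To make the difference concrete, here is a minimal sketch of why a prediction for one individual can sit away from the population-level points. It assumes a simple logistic model with a per-individual offset. All numbers are made up for illustration and are not estimates from the model above. The names inv_logit, population_intercept, and individual_offsets are hypothetical.

```python
import math


def inv_logit(x):
    """Map a value on the logit scale to a probability."""
    return 1.0 / (1.0 + math.exp(-x))


# Hypothetical values for illustration only (not estimates from the model above)
population_intercept = -0.5
individual_offsets = {0: -0.8, 1: 0.1, 2: 0.3, 3: 0.4}

# Prediction conditional on a single individual (here individual 0)
p_individual_0 = inv_logit(population_intercept + individual_offsets[0])

# Prediction averaged over all individuals
p_average = sum(
    inv_logit(population_intercept + offset)
    for offset in individual_offsets.values()
) / len(individual_offsets)

print(f"individual 0: {p_individual_0:.3f}, averaged: {p_average:.3f}")
```

In this toy setup individual 0 has a below-average offset, so its prediction is lower than the average over individuals, which is the kind of gap visible in the first plot.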

I have a couple more comments about this, but don't have the time this morning. I will follow up here within the next day. Thanks!

jt-lab commented 1 year ago

@GStechschulte, thanks a lot! That makes sense. So perhaps we should predict a single out-of-sample individual in this situation.

A quick update: the hdi bar issue described above was unrelated to this...

Looking forward to your further comments!

Many thanks for the support!

GStechschulte commented 1 year ago

Of course, anytime! I appreciate getting the feedback. Regarding

the hdi bar issue described above was unrelated to this...

I am looking into this. Thanks for also pointing this out.

plot_predictions was the first interpret plotting function, developed in early 2023. Over summer 2023, I added plot_comparisons and plot_slopes. In the latter two functions, it is possible to predict using the observed (empirical) data. These are "unit-level" predictions. Additionally, it is possible to pass your own values and/or to create a grid of values to use as the data fed to the model for prediction.

However, plot_predictions does not have this additional functionality yet. In marginaleffects, it does. Thus, once I add the ability to perform unit-level predictions, you could achieve the "desired / correct" plot with plot_predictions. Below, I give you an example in {marginaleffects} using your data and the random effects model:

library(brms)
library(marginaleffects)

dat <- read.csv("simulated_data_order_prob2.csv")

dat$factor3 = as.factor(dat$factor3)
dat$factor2 = as.factor(dat$factor2)
dat$factor1 = as.factor(dat$factor1)
formula <- bf(correct | trials(count) ~ 0 + factor3:factor2 + factor1 + (0 + factor3:factor2 | individual))
model <- brm(formula, data = dat, family = binomial)

# unit-level predictions averaged over individuals
plot_predictions(
  model,
  by=c("factor3", "factor2", "factor1")
)

[Figure: {marginaleffects} plot of the predictions averaged over individuals]

The plot shows the marginal estimates, averaged over all individuals.

@tomicapretto what do you think? All the pieces are there in interpret. It is just a matter of putting it together. Then, the predictions and plot_predictions functionality will be "close to" {marginaleffects}. Also, the function calls will be similar to comparisons and slopes resulting in a more standard API.

GStechschulte commented 1 year ago

Since the majority of the functions are already there, here's a working demo in Bambi now:

model = bmb.Model(
    "p(correct, count) ~ 0 + factor3:factor2 + factor1  + (0 + factor3:factor2 | individual)",
    data,
    family="binomial",
    categorical=["factor3", "individual", "factor1", "factor2"],
    priors=priors,
    noncentered=False
)

idata = model.fit(tune=2000, draws=2000, random_seed=123, init='adapt_diag',
                  target_accept=0.9, idata_kwargs={'log_likelihood':True})

bmb.interpret.plot_predictions(
    model=model, 
    idata=idata,
    average_by=["factor3", "factor2", "factor1"],
    fig_kwargs={"figsize": (12, 4), "sharey": True},
)

[Figure: Bambi plot_predictions output using average_by]

Side note: I know the {marginaleffects} plot and the Bambi plot aren't the same. Treat the {marginaleffects} plot above with caution: I produced it quickly and there were divergences, among other issues. It is only meant as an example implementation.

jt-lab commented 1 year ago

Thank you so much, @GStechschulte!

I just came here to thank you for your previous post with the explanations and examples, but this is of course even better! So do I get this right that using the average_by specification instead of the covariates does the trick?

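By the way, what do you think of these ideas:

- Add (optional) verbosity when plots are created. E.g., print out some info when defaults are applied or averaging etc. happens implicitly
- Add possibility to plot observations along with predictions (like I did manually above).

If you like I could try implementing these!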

Many thanks again

GStechschulte commented 1 year ago

@jt-lab thank you! 😄

So do I get this right that using the average_by specification instead of the covariates does the trick?

Right. Not passing any variables to covariates results in unit-level predictions, that is, one prediction per observed row and therefore per individual. These predictions are then grouped by the variables passed to average_by and averaged (a pandas DataFrame.groupby(average_by).mean()), which gives the marginal estimates for factor3, factor2, and factor1.
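The averaging step can be sketched with plain pandas. This is only an illustration of the groupby-then-mean idea, not Bambi's internal code. The unit_preds data frame and its values are made up.

```python
import pandas as pd

# Hypothetical unit-level predictions: one row per individual and factor combination
unit_preds = pd.DataFrame(
    {
        "individual": [0, 0, 1, 1, 2, 2],
        "factor3": ["1", "2", "1", "2", "1", "2"],
        "factor1": ["A", "A", "A", "A", "A", "A"],
        "estimate": [0.30, 0.50, 0.40, 0.60, 0.50, 0.70],
    }
)

# Variables to keep; everything else (here: individual) is averaged out
average_by = ["factor3", "factor1"]

marginal = unit_preds.groupby(average_by, as_index=False)["estimate"].mean()
print(marginal)
```

Each row of marginal is the mean of the per-individual estimates for one combination of the average_by variables.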

Add (optional) verbosity when plots are created. E.g., print out some info when defaults are applied or averaging etc. happens implicitly

I have come to realise that unless the user really studies the docs, it is difficult to understand everything that is being created and computed. Thus, I do like the idea to be more transparent (optionally). @tomicapretto do you have any thoughts?

Add possibility to plot observations along with predictions (like I did manually above).

I had not thought of this until I saw you do it. At the moment, I would like to limit the amount of plotting code we introduce (Matplotlib is not the most fun to develop and it is difficult to write tests for the content in the plots). Unless more users ask for this, I think I personally won't pursue it. Nonetheless, I liked your solution with seaborn 😄

Thanks for the ideas! 👍🏼

jt-lab commented 1 year ago

@GStechschulte,

We just wanted to try the average_by solution, but there is no average_by argument in plot_predictions. I also don't see it in the docs or in the code on GitHub. Even in your fork it's not there. So maybe I misunderstood that this was an already existing workaround? Or is there some secret branch it is on? :-D

At the moment, I would like to limit the amount of plotting code we introduce (Matplotlib is not the most fun to develop and it is difficult to write tests for the content in the plots).

I see, makes sense.

Nonetheless, I liked your solution with seaborn 😄

Yes, that works reasonably well. We had some order-related trouble again, as seaborn derives the order of the levels from their order in the data (unless specified otherwise) and plot_predictions uses a different order. So one has to watch out for that.
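One way to avoid the mismatch is to fix the level order explicitly instead of relying on the order in the data. The sketch below assumes that sorted levels are the order you want. Whether that matches the order plot_predictions uses is an assumption to check against your own plot. The helper sorted_levels and the toy data are hypothetical.

```python
import pandas as pd

# Toy data whose rows are deliberately not in sorted order
data = pd.DataFrame(
    {
        "factor3": ["2", "1", "2", "1"],
        "factor2": ["Y", "Y", "X", "X"],
        "factor1": ["B", "A", "A", "B"],
    }
)


def sorted_levels(df, column):
    """Return the unique levels of a column in sorted order."""
    return sorted(df[column].unique())


orders = {col: sorted_levels(data, col) for col in ["factor3", "factor2", "factor1"]}
print(orders)

# The lists can then be passed to seaborn explicitly, for example:
# sns.catplot(..., order=orders["factor3"], hue_order=orders["factor2"],
#             col_order=orders["factor1"])
```

The order, hue_order, and col_order arguments are existing seaborn catplot parameters; only the choice of sorted order is an assumption here.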

GStechschulte commented 1 year ago

@jt-lab you are too quick for me 😉 haha. I just pushed these changes in this branch.

Please note I still need to add error handling and tests to ensure everything works.

Cheers!

jt-lab commented 1 year ago

@jt-lab you are too quick for me 😉 haha. I just pushed these changes in this branch.

Please note I still need to add error handling and tests to ensure everything works.

Cheers!

Thanks and sorry for the impatience 😬

I thought you remembered an existing workaround

GStechschulte commented 1 year ago

@jt-lab I just pushed some more changes to that branch and opened draft PR #736

tomicapretto commented 1 year ago

@tomicapretto what do you think? All the pieces are there in interpret. It is just a matter of putting it together. Then, the predictions and plot_predictions functionality will be "close to" {marginaleffects}. Also, the function calls will be similar to comparisons and slopes resulting in a more standard API.

Just seeing the issue. Thanks for proactively writing the code and opening the PR :)

I have come to realise that unless the user really studies the docs, it is difficult to understand everything that is being created and computed. Thus, I do like the idea to be more transparent (optionally). @tomicapretto do you have any thoughts?

I like the idea too! I only want to make 2 points

I think one possible approach is to have a configuration instance that comes with a default option and users can do something like bmb.config.interpret_messages = False. This is what we have in formulae, check https://github.com/bambinos/formulae/blob/master/formulae/config.py and https://github.com/bambinos/formulae/blob/master/tests/test_config.py
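A minimal sketch of what such a configuration instance could look like is below. It is not the formulae implementation and not an existing Bambi API. The Config class, the interpret_messages option, and the notify helper are hypothetical names used only to illustrate the idea of a validated, module-level setting.

```python
class Config:
    """Minimal settings holder with validated options (illustrative only)."""

    # Option name -> allowed values; the first value is the default
    _FIELDS = {"interpret_messages": (True, False)}

    def __init__(self):
        for name, allowed in self._FIELDS.items():
            object.__setattr__(self, name, allowed[0])

    def __setattr__(self, name, value):
        if name not in self._FIELDS:
            raise KeyError(f"'{name}' is not a valid configuration option")
        if value not in self._FIELDS[name]:
            raise ValueError(f"{value!r} is not a valid value for '{name}'")
        object.__setattr__(self, name, value)


# A single shared instance, analogous to the proposed bmb.config
config = Config()


def notify(message):
    """Print a message only if messages are enabled; return what was printed."""
    if config.interpret_messages:
        print(message)
        return message
    return None


notify("Default computed for unspecified variable: individual")
```

Users would then switch the messages off with config.interpret_messages = False, and a misspelled option name raises an error instead of being silently ignored.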

Add possibility to plot observations along with predictions (like I did manually above)

@jt-lab I agree with @GStechschulte's response here. It's already a lot to maintain the existing functionality. Anything related to observed data should be deferred to the user.

@GStechschulte let's continue any related discussion in the PR where you implement the changes :)