Hey @jt-lab thanks a lot for raising the issue and sharing the code / dataset!
At a quick glance, this is because in the random effects model, `individual` is included as a term in the Bambi model, and since this term is not specified in `plot_predictions`, a default value of `individual=0` is used (since the term is categorical, Bambi takes the mode).
```python
bmb.interpret.predictions(
    model=model,
    idata=idata,
    covariates=["factor3", "factor2", "factor1"],
)
```
| factor3 | factor2 | factor1 | individual | estimate | lower_3.0% | upper_97.0% |
|---|---|---|---|---|---|---|
| 1 | X | A | 0 | 0.370441 | 0.331295 | 0.407852 |
| 2 | X | A | 0 | 0.545766 | 0.504590 | 0.586051 |
| 1 | Y | A | 0 | 0.301027 | 0.268440 | 0.337534 |
| ... | ... | ... | ... | ... | ... | ... |
| 2 | X | C | 0 | 0.453661 | 0.426735 | 0.480137 |
| 1 | Y | C | 0 | 0.229191 | 0.208711 | 0.249756 |
| 2 | Y | C | 0 | 0.383008 | 0.360000 | 0.407484 |
Thus, your first plot above seems to be comparing `individual=0` predictions with the population-level data points. Then, since the second model does not include a random effect (and thus no `individual` term), the predictions more closely match the population-level data points (as you stated).
I have a couple more comments about this, but don't have the time this morning. I will follow up here within the next day. Thanks!
@GStechschulte, thanks a lot! That makes sense. So perhaps we should predict a single out-of-sample individual in this situation.
A quick update: the HDI bar issue described above was unrelated to this...
Looking forward to your further comments!
Many thanks for the support!
Of course, anytime! I appreciate getting the feedback. Regarding

> the HDI bar issue described above was unrelated to this...

I am looking into it. Thanks for also pointing this out.
`plot_predictions` was the first `interpret` plotting function, developed in early 2023. Over summer 2023, I added `plot_comparisons` and `plot_slopes`. In the latter two functions, it is possible to predict using the observed (empirical) data. These are "unit-level" predictions. Additionally, it is possible to pass your own values and/or to create a grid of values to use as the data fed to the model to perform predictions (see the sketch below).
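For instance, here is a minimal sketch of passing user-specified values to `plot_comparisons`; the particular contrast and grid values are illustrative, not taken from this thread:

```python
# Sketch only: the contrast/conditional values below are made up for
# illustration. `contrast` sets the values of the variable being compared;
# `conditional` builds the grid of covariate values to predict on.
bmb.interpret.plot_comparisons(
    model=model,
    idata=idata,
    contrast={"factor3": ["1", "2"]},
    conditional={"factor2": ["X", "Y"], "factor1": ["A", "B", "C"]},
)
```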
However, `plot_predictions` does not have this additional functionality yet. In {marginaleffects}, it does. Thus, once I add the ability to perform unit-level predictions, you could achieve the "desired / correct" plot with `plot_predictions`. Below, I give you an example in {marginaleffects} using your data and the random effects model:
```r
library(brms)
library(marginaleffects)

dat <- read.csv("simulated_data_order_prob2.csv")
dat$factor3 <- as.factor(dat$factor3)
dat$factor2 <- as.factor(dat$factor2)
dat$factor1 <- as.factor(dat$factor1)

formula <- bf(
  correct | trials(count) ~ 0 + factor3:factor2 + factor1 +
    (0 + factor3:factor2 | individual)
)
model <- brm(formula, data = dat, family = binomial)

# unit-level predictions averaged over individuals
plot_predictions(
  model,
  by = c("factor3", "factor2", "factor1")
)
```
The plot shows the marginal estimates, averaged over all individuals.
@tomicapretto what do you think? All the pieces are there in `interpret`. It is just a matter of putting it together. Then, the `predictions` and `plot_predictions` functionality will be "close to" {marginaleffects}. Also, the function calls will be similar to `comparisons` and `slopes`, resulting in a more standard API.
Since the majority of the functions are already there, here's a working demo in Bambi now:
```python
# `data` and `priors` come from the original script shared in this issue.
model = bmb.Model(
    "p(correct, count) ~ 0 + factor3:factor2 + factor1 + (0 + factor3:factor2 | individual)",
    data,
    family="binomial",
    categorical=["factor3", "individual", "factor1", "factor2"],
    priors=priors,
    noncentered=False,
)
idata = model.fit(
    tune=2000, draws=2000, random_seed=123, init="adapt_diag",
    target_accept=0.9, idata_kwargs={"log_likelihood": True},
)

bmb.interpret.plot_predictions(
    model=model,
    idata=idata,
    average_by=["factor3", "factor2", "factor1"],
    fig_kwargs={"figsize": (12, 4), "sharey": True},
)
```
Side note: I know the {marginaleffects} plot and the Bambi plot aren't the same. I would take the {marginaleffects} plot above with caution, as I did it quickly and there were divergences, etc. It was only meant as an example implementation.
Thank you so much, @GStechschulte!
I just came here to thank you for your previous post with the explanations and examples, but this is of course even better! So do I get this right that using the `average_by` specification instead of `covariates` does the trick?
By the way, what do you think of these ideas:

- Add (optional) verbosity when plots are created. E.g., print out some info when defaults are applied or averaging etc. happens implicitly.
- Add possibility to plot observations along with predictions (like I did manually above).

If you like, I could try implementing these!
Many thanks again
@jt-lab thank you!
> So do I get this right that using the `average_by` specification instead of `covariates` does the trick?

Right. Not passing any variables to `covariates` results in unit-level predictions. Then, since a prediction is made for each individual, a `groupby(average_by).mean()` is applied: the predictions are grouped by the variables passed to `average_by`, and `.mean()` computes the marginal estimate for `factor3`, `factor2`, `factor1`.
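A minimal pandas sketch of that reduction (the `unit_level` frame and its column names are illustrative, not Bambi's internal schema):

```python
import pandas as pd

# Illustrative unit-level predictions: one estimate per individual per cell.
unit_level = pd.DataFrame({
    "factor3": ["1", "1", "2", "2"],
    "factor2": ["X", "X", "X", "X"],
    "factor1": ["A", "A", "A", "A"],
    "individual": [0, 1, 0, 1],
    "estimate": [0.37, 0.41, 0.55, 0.51],
})

average_by = ["factor3", "factor2", "factor1"]
# Group by the requested variables and average over individuals.
marginal = unit_level.groupby(average_by)["estimate"].mean().reset_index()
print(marginal)
```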
> Add (optional) verbosity when plots are created. E.g., print out some info when defaults are applied or averaging etc. happens implicitly.
I have come to realise that unless the user really studies the docs, it is difficult to understand everything that is being created and computed. Thus, I do like the idea of (optionally) being more transparent. @tomicapretto do you have any thoughts?
> Add possibility to plot observations along with predictions (like I did manually above).
I had not thought of this until I saw you do it. At the moment, I would like to limit the amount of plotting code we introduce (Matplotlib is not the most fun to develop with, and it is difficult to write tests for the content of plots). Unless more users ask for this, I don't think I will pursue it personally. Nonetheless, I liked your solution with seaborn!
Thanks for the ideas!
@GStechschulte, we just wanted to try the `average_by` solution, but there is no `average_by` argument in `plot_predictions`. We also don't see it in the docs or the code on GitHub; even in your fork it's not there. So maybe we misunderstood and this isn't an already existing workaround? Or is there some secret branch it is on? :-D
> At the moment, I would like to limit the amount of plotting code we introduce (Matplotlib is not the most fun to develop with, and it is difficult to write tests for the content of plots).
I see, makes sense.
> Nonetheless, I liked your solution with seaborn!
Yes, that works okay-ish. We ran into order-related trouble again: seaborn derives the category order from the order in the data (if not specified otherwise), while `plot_predictions` uses a different order. So one has to watch out for that.
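A minimal sketch of pinning the category order explicitly in seaborn so it matches the prediction plot (the column names follow the simulated dataset; the sorted order is just one illustrative choice):

```python
import pandas as pd
import seaborn as sns

data = pd.read_csv("simulated_data_order_prob2.csv")
data["prop_correct"] = data["correct"] / data["count"]

# Pass `order` / `hue_order` explicitly instead of relying on data order.
sns.stripplot(
    data=data,
    x="factor3",
    y="prop_correct",
    hue="factor2",
    order=sorted(data["factor3"].unique()),
    hue_order=sorted(data["factor2"].unique()),
    dodge=True,
)
```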
@jt-lab you are too quick for me, haha! I just pushed these changes in this branch.
Please note I still need to add error handling and tests to ensure everything works.
Cheers!
Thanks, and sorry for the impatience! I thought you were referring to an already existing workaround.
@jt-lab I just pushed some more changes to that branch and opened draft PR #736
> @tomicapretto what do you think? All the pieces are there in `interpret`. It is just a matter of putting it together. Then, the `predictions` and `plot_predictions` functionality will be "close to" {marginaleffects}. Also, the function calls will be similar to `comparisons` and `slopes`, resulting in a more standard API.
Just seeing the issue. Thanks for proactively writing the code and opening the PR :)
> I have come to realise that unless the user really studies the docs, it is difficult to understand everything that is being created and computed. Thus, I do like the idea of (optionally) being more transparent. @tomicapretto do you have any thoughts?
I like the idea too! I only want to make two points:
I think one possible approach is to have a configuration instance that comes with default options, so users can do something like `bmb.config.interpret_messages = False`. This is what we have in formulae; check https://github.com/bambinos/formulae/blob/master/formulae/config.py and https://github.com/bambinos/formulae/blob/master/tests/test_config.py
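A minimal sketch of such a configuration object, loosely modeled on formulae's config; the `interpret_messages` option and the module-level `config` instance are illustrative names, not an existing Bambi API:

```python
class Config:
    """Global options with defaults; loosely modeled on formulae's config."""

    # Default options; `interpret_messages` is a hypothetical name.
    _DEFAULTS = {"interpret_messages": True}

    def __init__(self):
        self._options = dict(self._DEFAULTS)

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails.
        try:
            return self._options[name]
        except KeyError:
            raise AttributeError(f"Unknown option {name!r}") from None

    def __setattr__(self, name, value):
        if name == "_options":
            super().__setattr__(name, value)
        elif name in self._DEFAULTS:
            self._options[name] = value
        else:
            raise AttributeError(f"Unknown option {name!r}")


config = Config()

# Usage: silence interpret messages globally.
config.interpret_messages = False
```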
> Add possibility to plot observations along with predictions (like I did manually above).
@jt-lab I agree with @GStechschulte's response here. It's already a lot to maintain the existing functionality. Anything related to observed data should be deferred to the user.
@GStechschulte let's continue any related discussion in the PR where you implement the changes :)
@GStechschulte, we have played around with `plot_predictions` some more and came across some behavior we don't understand. It might be an issue with the handling of random effects, hence I describe it here:

For a model with random effects (`p(correct, count) ~ 0 + factor3:factor2 + factor1 + (0 + factor3:factor2 | individual)`), the predictions seem to be off compared to the data points (e.g., see Factor1=A, Factor2=orange, Factor3=1, but also others):

Also, in factor C, the HDI bars get a bit smaller. It's barely visible here, but in another dataset (which I cannot share) they are about 3 times smaller than those of levels A and B, apparently without any reason related to the data.
The same model, but without the random effects, produces predictions that are pretty close to the empirical means:
Many thanks in advance!
Code to reproduce this and the dataset:
Data Set:
simulated_data_order_prob2.csv