NathanielF opened this issue 7 months ago
I imagine there is quite a bit of work involved in adapting the interpret functions to account for the static spline bases, but I'm wondering whether there is something obvious I'm missing?
Hey @NathanielF, thanks for raising the issue. This should work. What version of Bambi are you using?
A couple of months ago we improved how the interpret package parses HSGP and basis-spline terms in Bambi models. You may need to install from source to get these latest changes.
Nonetheless, I will attempt to reproduce your example and get back to you.
Thanks @GStechschulte, I'm using version 0.13.
I'd be keen to find out whether I can use Bambi's native interpret functions.
For the moment I've hacked together a routine to extract the components:
import arviz as az
import matplotlib.pyplot as plt
import numpy as np

# Design-matrix columns of each spline term and the matching posterior
# coefficient draws (az.extract stacks chains/draws into a "sample" dim).
Bincome = spline_model.response_component.design.common['bs(Income, degree=3, knots=knots_income)']
income_coefs = az.extract(spline_idata['posterior']['bs(Income, degree=3, knots=knots_income)'])['bs(Income, degree=3, knots=knots_income)']
Bedu = spline_model.response_component.design.common['bs(Edu, degree=3, knots=knots_edu)']
edu_coefs = az.extract(spline_idata['posterior']['bs(Edu, degree=3, knots=knots_edu)'])['bs(Edu, degree=3, knots=knots_edu)']
Bhealth = spline_model.response_component.design.common['bs(Health, degree=3, knots=knots_health)']
health_coefs = az.extract(spline_idata['posterior']['bs(Health, degree=3, knots=knots_health)'])['bs(Health, degree=3, knots=knots_health)']

# Contribution of each term for every posterior sample: (n_samples, n_obs)
income = np.dot(Bincome, income_coefs).T
edu = np.dot(Bedu, edu_coefs).T
health = np.dot(Bhealth, health_coefs).T
intercept = az.extract(spline_idata['posterior']['Intercept'])['Intercept'].values

fig, ax = plt.subplots(figsize=(9, 7))
# Overlay the first 100 posterior samples; label only one so the legend stays readable.
for i in range(100):
    if i == 1:
        ax.plot(income[i], label='Income Component', color='red')
        ax.plot(edu[i], label='Edu Component', color='blue')
        ax.plot(health[i], label='Health Component', color='darkgreen')
        ax.plot(intercept[i] + income[i] + edu[i] + health[i], label='Combined Components', color='purple')
    else:
        ax.plot(income[i], alpha=0.1, color='red')
        ax.plot(edu[i], alpha=0.1, color='blue')
        ax.plot(health[i], alpha=0.1, color='darkgreen')
        ax.plot(intercept[i] + income[i] + edu[i] + health[i], color='purple', alpha=0.3)
ax.scatter(range(len(spline_idata['observed_data']['Overall'])), spline_idata['observed_data']['Overall'], label='Observed', color='grey', s=56, ec='black')
ax.set_title("Additive Spline Components", fontsize=20)
ax.legend()
# Fix the tick positions first so each country label sits on its observation.
countries = pisa_df.dropna(axis=0).reset_index()['Country']
ax.set_xticks(range(len(countries)))
ax.set_xticklabels(countries)
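The routine above boils down to one piece of arithmetic per term: the term's design matrix times each posterior sample of its coefficients, summed with the intercept. A minimal vectorized sketch of that arithmetic, using synthetic arrays as hypothetical stand-ins for the Bambi design matrices and posterior draws (the names `terms`, `components` and the shapes are assumptions for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
n_obs, n_draws = 20, 100

# Hypothetical stand-ins for three bs(...) terms: a design matrix
# (n_obs, n_basis) and coefficient draws (n_basis, n_draws) per term,
# matching the layout az.extract returns for a single variable.
terms = {
    name: (rng.uniform(size=(n_obs, k)), rng.normal(size=(k, n_draws)))
    for name, k in [("Income", 6), ("Edu", 5), ("Health", 7)]
}
intercept = rng.normal(size=n_draws)

# Per-term contribution for every draw, shape (n_draws, n_obs),
# i.e. np.dot(B, coefs).T as in the routine above.
components = {name: (B @ beta).T for name, (B, beta) in terms.items()}

# Additive predictor: one intercept per draw broadcast over observations,
# plus the sum of all term contributions.
combined = intercept[:, None] + sum(components.values())

# Posterior mean of the combined curve at each observation.
mean_curve = combined.mean(axis=0)
```

Broadcasting the intercept with `[:, None]` replaces the per-draw Python loop; the plotting loop can then index `combined[i]` directly.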
which gives me this plot:
Work in progress here: https://nathanielf.github.io/posts/post-with-code/GAMs_and_GPs/gams_and_gps.html
It dives into splines, hierarchical spline modelling, and HSGP approximations using Bambi. It would be great if I could show off the full functionality.
@NathanielF @GStechschulte I'm working on this. I think the issue is that we look up some variable in the data frame when that should not be the case.
@GStechschulte I see in the code @NathanielF shared he's setting wrt="Overall", when "Overall" is the response variable. I think that's not the intended usage of the argument, right?
@NathanielF are you trying to use plot_slopes() or plot_predictions()? Just note that I'm not an expert on plot_slopes() (Gabriel knows much more about it than I do), but I'm guessing what you want is plot_predictions().
Ah OK, that may be it. I will have another look!
It's still broken in the latest release and on the main branch. I'm about to push a branch with a quick fix and open an issue explaining the problem. Thanks @NathanielF for raising this; it allowed us to find a bug :)
Nice! Happy to help
Is this something close to what you're looking for?
fig, ax = bmb.interpret.plot_predictions(
    spline_model, spline_idata, "Edu"
)
fig, ax = bmb.interpret.plot_predictions(
    spline_model, spline_idata, {"Edu": np.linspace(0.7, 1, num=100)}
)
@NathanielF this is the branch https://github.com/tomicapretto/bambi/tree/spline_interpret and this is a notebook with the code https://github.com/tomicapretto/bambi/blob/spline_interpret/spline_prediction.ipynb (I'll delete it before merging into Bambi's main)
The main point is that in conditional we pass the names of variables, not the names of terms. That is because a variable can play a role in many terms, for example in 1 + x + x:z. The same happens when you use a transformation such as log(x), exp(x) or bs(x): we only need to pass the name of the covariate, not the name of the term.
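One way to see why the covariate name is enough: once the knots and degree are fixed at fit time, the columns of a bs(x) term are a deterministic function of x alone, so new values of x suffice to rebuild the term for prediction. A minimal numpy sketch of that idea using the Cox-de Boor recursion; `bspline_basis` is a hypothetical helper for illustration, not formulae's or Bambi's bs() implementation (which differs in details such as whether an intercept column is kept), and the knots and data are made up:

```python
import numpy as np


def bspline_basis(x, interior_knots, degree=3, lower=None, upper=None):
    """Full clamped B-spline basis for x, one column per basis function.

    Hypothetical illustration via the Cox-de Boor recursion.
    """
    x = np.asarray(x, dtype=float)
    lower = x.min() if lower is None else lower
    upper = x.max() if upper is None else upper
    # Knot vector: boundary knots repeated (degree + 1) times around the interior knots.
    t = np.concatenate([
        np.repeat(lower, degree + 1),
        np.sort(np.asarray(interior_knots, dtype=float)),
        np.repeat(upper, degree + 1),
    ])
    # Degree 0: indicator of each knot span; the last non-empty span also includes `upper`.
    B = np.zeros((x.size, len(t) - 1))
    for i in range(len(t) - 1):
        if t[i] < t[i + 1]:
            in_span = (x >= t[i]) & (x < t[i + 1])
            if t[i + 1] == upper:
                in_span |= x == upper
            B[:, i] = in_span
    # Raise the degree one step at a time; terms with a zero-width denominator are dropped.
    for d in range(1, degree + 1):
        B_next = np.zeros((x.size, len(t) - d - 1))
        for i in range(len(t) - d - 1):
            left = t[i + d] - t[i]
            right = t[i + d + 1] - t[i + 1]
            if left > 0:
                B_next[:, i] += (x - t[i]) / left * B[:, i]
            if right > 0:
                B_next[:, i] += (t[i + d + 1] - x) / right * B[:, i + 1]
        B = B_next
    return B


# Hypothetical knots and training data, fixed "at fit time".
knots = [0.3, 0.5, 0.7]
x_train = np.linspace(0.0, 1.0, 50)
B_train = bspline_basis(x_train, knots)

# New raw covariate values (like those passed through conditional): reusing the
# same knots and bounds rebuilds the term's columns, so no term name is needed.
x_new = np.array([0.1, 0.45, 0.9])
B_new = bspline_basis(x_new, knots, lower=x_train.min(), upper=x_train.max())
```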
The plot_predictions code works for me too. But, just out of curiosity, I tried plot_slopes without the Overall variable:
fig, ax = bmb.interpret.plot_slopes(
    spline_model,
    spline_idata,
    wrt="Health",
    conditional={
        "Income": np.linspace(0.1, 0.5, 10),
    }
)
I've changed these to two covariates and I'm now getting a different error:
TypeError                                 Traceback (most recent call last)
/Users/nathanielforde/Documents/Github/NathanielF.github.io/posts/post-with-code/GAMs_and_GPs/gams_and_gps.qmd in line 1
----> 963 fig, ax = bmb.interpret.plot_slopes(
      964     spline_model,
      965     spline_idata,
      966     wrt="Health",
      967     conditional={
      968         "Income": np.linspace(0.1, 0.5, 10),
      969     }
      970 )

File ~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py:522, in plot_slopes(model, idata, wrt, conditional, average_by, eps, slope, sample_new_groups, use_hdi, prob, transforms, legend, ax, fig_kwargs, subplot_kwargs)
    518     raise ValueError("'slope' must be one of ('dydx', 'dyex', 'eyex', 'eydx')")
    520 conditional_info = ConditionalInfo(model, conditional)
--> 522 slopes_summary = slopes(
    523     model=model,
    524     idata=idata,
    525     wrt=wrt,
    526     conditional=conditional,
    527     average_by=average_by,
    528     eps=eps,
    529     slope=slope,
    530     use_hdi=use_hdi,
    531     prob=prob,
    532     transforms=transforms,
    533     sample_new_groups=sample_new_groups,
    534 )
    536 return _plot_differences(
    537     model=model,
    538     conditional_info=conditional_info,
   (...)
    545     subplot_kwargs=subplot_kwargs,
    546 )

File ~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/effects.py:855, in slopes(model, idata, wrt, conditional, average_by, eps, slope, use_hdi, prob, transforms, sample_new_groups)
    852 response = ResponseInfo(response_name, "mean", lower_bound, upper_bound)
    853 response_transform = transforms.get(response_name, identity)
--> 855 slopes_data = create_differences_data(
    856     conditional_info, wrt_info, conditional_info.user_passed, effect_type
    857 )
    858 idata = model.predict(
    859     idata, data=slopes_data, sample_new_groups=sample_new_groups, inplace=False
    860 )
    862 # returns empty array if model predictions do not have multiple dimensions

File ~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/create_data.py:198, in create_differences_data(condition_info, variable_info, user_passed, kind)
    195 if not condition_info.covariates:
    196     return _differences_unit_level(variable_info, kind)
--> 198 return _grid_level(condition_info, variable_info, user_passed, kind)

File ~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/create_data.py:124, in _grid_level(condition_info, variable_info, user_passed, kind)
    119 data_grid = enforce_dtypes(condition_info.model.data, data_grid, except_col)
    121 # After computing default values, fractional values may have been computed.
    122 # Enforcing the dtype of "int" may create duplicate rows as it will round
    123 # the fractional values.
--> 124 data_grid = data_grid.drop_duplicates()
    126 return data_grid.reset_index(drop=True)

File ~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py:6802, in DataFrame.drop_duplicates(self, subset, keep, inplace, ignore_index)
   6799 inplace = validate_bool_kwarg(inplace, "inplace")
   6800 ignore_index = validate_bool_kwarg(ignore_index, "ignore_index")
-> 6802 result = self[-self.duplicated(subset, keep=keep)]
   6803 if ignore_index:
   6804     result.index = default_index(len(result))

File ~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py:6942, in DataFrame.duplicated(self, subset, keep)
   6940 else:
   6941     vals = (col.values for name, col in self.items() if name in subset)
-> 6942     labels, shape = map(list, zip(*map(f, vals)))
   6944 ids = get_group_index(labels, tuple(shape), sort=False, xnull=False)
   6945 result = self._constructor_sliced(duplicated(ids, keep), index=self.index)

File ~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py:6910, in DataFrame.duplicated.<locals>.f(vals)
   6909 def f(vals) -> tuple[np.ndarray, int]:
-> 6910     labels, shape = algorithms.factorize(vals, size_hint=len(self))
   6911     return labels.astype("i8", copy=False), len(shape)

File ~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py:795, in factorize(values, sort, use_na_sentinel, size_hint)
    792 # Don't modify (potentially user-provided) array
    793 values = np.where(null_mask, na_value, values)
--> 795 codes, uniques = factorize_array(
    796     values,
    797     use_na_sentinel=use_na_sentinel,
    798     size_hint=size_hint,
    799 )
    801 if sort and len(uniques) > 0:
    802     uniques, codes = safe_sort(
    803         uniques,
    804         codes,
   (...)
    807         verify=False,
    808     )

File ~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py:595, in factorize_array(values, use_na_sentinel, size_hint, na_value, mask)
    592 hash_klass, values = _get_hashtable_algo(values)
    594 table = hash_klass(size_hint or len(values))
--> 595 uniques, codes = table.factorize(
    596     values,
    597     na_sentinel=-1,
    598     na_value=na_value,
    599     mask=mask,
    600     ignore_na=use_na_sentinel,
    601 )
    603 # re-cast e.g. i8->dt64/td64, uint8->bool
    604 uniques = _reconstruct_data(uniques, original.dtype, original)

File pandas/_libs/hashtable_class_helper.pxi:7281, in pandas._libs.hashtable.PyObjectHashTable.factorize()

File pandas/_libs/hashtable_class_helper.pxi:7195, in pandas._libs.hashtable.PyObjectHashTable._unique()

TypeError: unhashable type: 'numpy.ndarray'
It works if we convert the object to a list:
fig, ax = bmb.interpret.plot_slopes(
    spline_model,
    spline_idata,
    wrt="Edu",
    conditional={
        "Income": list(np.linspace(0.1, 1, 10))
    }
)
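The traceback ends in DataFrame.drop_duplicates, which has to hash every cell of the data grid to find repeated rows. A minimal pandas-only sketch of how that exact TypeError arises when a whole NumPy array ends up inside a single cell (object dtype) instead of being spread across rows; this assumes, but does not confirm, that something similar happened in the 0.13 grid construction, and the column names are simply borrowed from the example:

```python
import numpy as np
import pandas as pd

# Scalar cells: each value hashes fine, so the repeated row is dropped.
ok_grid = pd.DataFrame({"Income": [0.1, 0.1, 0.5], "Health": [0.9, 0.9, 0.9]})
deduped = ok_grid.drop_duplicates()

# Array-valued cells (object dtype): each cell holds a whole ndarray,
# which Python cannot hash, so the same call fails.
cells = np.empty(2, dtype=object)
cells[0] = np.linspace(0.1, 0.5, 10)
cells[1] = np.linspace(0.1, 0.5, 10)
bad_grid = pd.DataFrame({"Income": cells, "Health": [0.9, 0.9]})

try:
    bad_grid.drop_duplicates()
    message = None
except TypeError as err:
    message = str(err)  # "unhashable type: 'numpy.ndarray'"
```

Converting the values to a list sidesteps the problem only if it changes how the grid gets built; the hotfix discussed later in the thread made plain arrays work again.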
But it gives a rather nonsensical plot.
OK, I'm starting to make progress. I think I just wasn't using the API correctly.
If I pass in a dictionary for the wrt variable in the slopes calculation, then I start to see variation across the income variable:
bmb.interpret.plot_slopes(
    spline_model,
    spline_idata,
    wrt={"Edu": [.8, .9]},
    conditional={
        "Income": list(np.linspace(0.1, 1, 10)),
        "Health": list(np.linspace(0.1, 1, 10))
    }
)
I guess I'm just not entirely sure what plot_slopes will do with a spline-based model?
Thanks a lot @tomicapretto and @NathanielF for your comments.
@GStechschulte I see in the code @NathanielF shared he's setting wrt="Overall", when "Overall" is the response variable. I think that's not the intended usage of the argument, right?
That's correct. Generally speaking, for all of the interpret functions we are interested in how the response variable changes given some values of the covariates.
@NathanielF are you trying to use plot_slopes() or plot_predictions()? Just note that I'm not an expert on plot_slopes() (Gabriel knows much more about it than I do), but I'm guessing what you want is plot_predictions().
Yeah, it seems plot_predictions is what you want, though I need to read your blog post first to confirm. This isn't to say you can't use plot_slopes; slopes should work too.
I guess I'm just not entirely sure what plot_slopes will do with a spline-based model?
plot_slopes isn't doing anything per se with the model. What it is doing is saying: "Okay, cool. You have defined and fit a model on some data. Now let's generate some 'new data', feed that data to the model, and see what the predictions are." I think Richard McElreath uses the phrase "seeing through the eyes of the model".
Under the hood, slopes assembles a pairwise grid of values that lets you compare the predictions of the response variable for a small change in the wrt predictor, conditional on the other covariates specified in the model. We use finite differences to approximate the derivative.
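The grid-plus-finite-difference idea can be sketched without Bambi: cross the wrt values with the conditional values, predict at wrt and at wrt + eps for every posterior draw, and divide the difference by eps (a forward difference). The `predict` function below is a hypothetical toy model with random per-draw coefficients, not a fitted spline model, and the eps value and column names are assumptions for illustration; Bambi's own slopes additionally handles averaging, intervals, and the other slope types:

```python
import itertools

import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
n_draws = 200
# Hypothetical posterior draws for a toy model: mu = a*sin(3*Health) + b*Income.
a = rng.normal(1.0, 0.1, size=n_draws)
b = rng.normal(2.0, 0.1, size=n_draws)


def predict(data):
    """Mean prediction for every draw and grid row: shape (n_draws, n_rows)."""
    health = data["Health"].to_numpy()
    income = data["Income"].to_numpy()
    return a[:, None] * np.sin(3 * health) + b[:, None] * income


# Grid: every wrt value crossed with every conditional value.
health_values = np.linspace(0.1, 1.0, 5)
income_values = [0.2, 0.5]
grid = pd.DataFrame(
    list(itertools.product(health_values, income_values)),
    columns=["Health", "Income"],
)

# Forward difference: shift only the wrt variable by eps.
eps = 1e-4
shifted = grid.assign(Health=grid["Health"] + eps)
slope_draws = (predict(shifted) - predict(grid)) / eps  # (n_draws, n_rows)

# Posterior mean slope for each grid row.
grid["dydx"] = slope_draws.mean(axis=0)
```

Because Income enters this toy model additively, the slope with respect to Health is the same at every Income value; with interactions or a non-identity link it would vary across the conditional values.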
Regarding your two comments:
It works if we convert the object to a list
Not sure why this is required in your example. When I run our documentation examples for slopes, where an np.ndarray is passed as the value, it works.
If I pass in a dictionary for the wrt variable in the slopes calculation, then I start to see variation across the income variable
Same as above. You shouldn't have to pass a dict.
I am looking into the last two comments. :)
@NathanielF regarding your comment:
It works if we convert the object to a list
I don't have this problem. After cherry-picking @tomicapretto's hotfix commit, I can run:
fig, ax = bmb.interpret.plot_slopes(
    spline_model,
    spline_idata,
    wrt={"Edu": [.8, .9]},
    conditional={
        "Income": np.linspace(0.1, 1, 10),  # No need to convert array to list
        "Health": np.linspace(0.1, 1, 3)  # No need to convert array to list
    }
)
I also tested a couple of other data types for the values (array, list, and float):
fig, ax = bmb.interpret.plot_predictions(
    spline_model,
    spline_idata,
    {
        "Edu": np.linspace(0.7, 1, num=100),
        "Income": [0.1, 0.3, 0.5, 0.7, 0.9],
        "Health": 0.893
    },
    subplot_kwargs={"main": "Edu", "group": "Income", "panel": "Income"},
    fig_kwargs={"figsize": (12, 4)},
    legend=False
)
plt.tight_layout()
and everything works as expected. It is interesting to see how the predictions for Overall change significantly once Income is $> 0.5$.
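The mix of array, list, and scalar values above ends up as one table of new data to predict on. A rough pandas sketch of that Cartesian-product idea, assuming (not confirming) that Bambi expands the conditional dict this way internally: arrays and lists become axes as-is, and a scalar becomes a one-element axis:

```python
import itertools

import numpy as np
import pandas as pd

conditional = {
    "Edu": np.linspace(0.7, 1, num=100),
    "Income": [0.1, 0.3, 0.5, 0.7, 0.9],
    "Health": 0.893,
}

# Treat every value as a 1-D axis: arrays and lists as-is, scalars as length 1.
axes = {name: np.atleast_1d(values) for name, values in conditional.items()}

# Every combination of the axes becomes one row of new data for prediction.
grid = pd.DataFrame(
    list(itertools.product(*axes.values())),
    columns=list(axes.keys()),
)
```

With 100 Edu values, 5 Income values and one Health value, this gives 500 rows, i.e. one Edu curve per Income value at fixed Health.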
Very good. Not sure about the array/list issue, but I was also able to get plot_predictions working to show similar effects of Income. The response just blows up at extreme values, which threw me.
Thanks for your work on this!
Unrelated, but I wondered if you had thoughts on the hierarchical spline model I've implemented in that blog post? Or how something similar could be handled in Bambi? Especially the predicting on new years/groups step.
@NathanielF anytime! I hope to read your blog in depth over the coming days 👍🏼
Unrelated, but I wondered if you had thoughts on the hierarchical spline model I've implemented in that blog post? Or how something similar could be handled in Bambi? Especially the predicting on new years/groups step.
I haven't looked into it in depth, but I'm very curious about it, so I will read it and provide feedback here if you want.
That would be amazing, @tomicapretto. I think all the mechanics are there in the blog post; I just need to finesse the writing a bit.
FYI, the basic theme I want to interrogate is that unpooled and pooled spline models seem brittle for extrapolation, whereas a hierarchical model seems to offer a way to "abstract" over the individual loss curves and pull out something like their generic shape to predict on new years.
I'm thinking this is a good way to balance the tendency of spline models to overfit?
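The "abstract over the individual curves" idea can be sketched with plain numpy: each group's spline coefficients are shared coefficients plus a group offset with scale sigma, and a curve for an unseen group comes from drawing a fresh offset from that population distribution for every posterior draw. All arrays here are synthetic, hypothetical stand-ins for posterior draws and a basis matrix; this is not the model from the blog post, nor Bambi's sample_new_groups implementation:

```python
import numpy as np

rng = np.random.default_rng(2)
n_draws, n_basis, n_x = 500, 6, 40

# Hypothetical posterior draws: shared (population-level) spline coefficients
# and the standard deviation of the group-level offsets.
beta_common = rng.normal(0.0, 1.0, size=(n_draws, n_basis))
sigma_group = np.abs(rng.normal(0.8, 0.05, size=n_draws))

# Hypothetical basis matrix on a grid of x values, rows normalised to sum to 1.
basis = rng.uniform(size=(n_x, n_basis))
basis /= basis.sum(axis=1, keepdims=True)

# Unseen group: per draw, a fresh offset ~ N(0, sigma_group) for each basis
# coefficient, added to the shared coefficients.
offsets = rng.normal(size=(n_draws, n_basis)) * sigma_group[:, None]
beta_new_group = beta_common + offsets

# Curves for the generic shape and for the new group: (n_draws, n_x).
curve_common = beta_common @ basis.T
curve_new_group = beta_new_group @ basis.T

# Across draws, the new group's curves centre on the shared shape but spread wider.
spread_common = curve_common.std(axis=0)
spread_new = curve_new_group.std(axis=0)
```

The extra spread is the price of predicting for a group the model has never seen; how much depends on how large the posterior for sigma is, which is exactly what pooling learns from the observed groups.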
Just FYI @tomicapretto and @GStechschulte, I've tightened the writing here and shown how to sample for "new groups" with my hierarchical spline model for the insurance curves. Additionally, I've shown how adding hierarchies to multiple-regression spline models can improve extrapolation to new groups.
@NathanielF thanks for the update! I'm reading through it right now (it's been on my mind, but I couldn't find the time).
No worries. Appreciate any feedback you might have.
I'm possibly doing something a bit naive, but I just wondered whether there were interpretation functions available for spline-based models like the ones you get with mgcv in R.
This breaks with the following error:
The last command breaks with the following error: