bambinos / bambi

BAyesian Model-Building Interface (Bambi) in Python.
https://bambinos.github.io/bambi/
MIT License

Potential Issue re: Spline models and Bambi's interpret module #796

Open NathanielF opened 7 months ago

NathanielF commented 7 months ago

I'm possibly doing something a bit naive, but I just wondered whether there are interpretation functions available for spline-based models, like you get with mgcv in R.

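import numpy as np
import pandas as pd
import bambi as bmb

random_seed = 100  # assumed value; the original post does not show how the seed was set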

pisa_df = pd.read_csv("https://raw.githubusercontent.com/m-clark/generalized-additive-models/master/data/pisasci2006.csv")

formula = "Overall ~ Income + Edu + Health"
base_model = bmb.Model(formula, pisa_df, dropna=True)

knots_income = np.linspace(np.min(pisa_df['Income']), np.max(pisa_df['Income']), 5+2)[1:-1]

knots_edu = np.linspace(np.min(pisa_df['Edu']), np.max(pisa_df['Edu']), 5+2)[1:-1]

knots_health = np.linspace(np.min(pisa_df['Health']), np.max(pisa_df['Health']), 5+2)[1:-1]

formula_spline = """Overall ~ bs(Income, degree=3, knots=knots_income) + bs(Edu, degree=3, knots=knots_edu) + bs(Health, degree=3, knots=knots_health)"""

spline_model = bmb.Model(formula_spline, pisa_df, dropna=True)

base_idata = base_model.fit(random_seed=random_seed, idata_kwargs={"log_likelihood": True})
spline_idata = spline_model.fit(random_seed=random_seed, idata_kwargs={"log_likelihood": True}, target_accept=.95)
base_model.predict(base_idata, kind='pps')
spline_model.predict(spline_idata, kind='pps')

##### Now try plotting with interpret

fig, ax = bmb.interpret.plot_slopes(
    spline_model,
    spline_idata,
    wrt="Overall",
    conditional={
        "bs(Edu, degree=3, knots=knots_edu)": np.linspace(0.6, 1, 50),
    },
)

This breaks with the following error:

UnboundLocalError                         Traceback (most recent call last)
/Users/nathanielforde/Documents/Github/NathanielF.github.io/posts/post-with-code/GAMs_and_GPs/gams_and_gps.qmd in line 1
----> 924 fig, ax = bmb.interpret.plot_slopes(
      925     spline_model,
      926     spline_idata,
      927     wrt="Overall",
      928     conditional={
      929         "bs(Edu, degree=3, knots=knots_edu)": np.linspace(0.6, 1, 50),
      930     },
      931 )

File ~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py:522, in plot_slopes(model, idata, wrt, conditional, average_by, eps, slope, sample_new_groups, use_hdi, prob, transforms, legend, ax, fig_kwargs, subplot_kwargs)
    518     raise ValueError("'slope' must be one of ('dydx', 'dyex', 'eyex', 'eydx')")
    520 conditional_info = ConditionalInfo(model, conditional)
--> 522 slopes_summary = slopes(
    523     model=model,
    524     idata=idata,
    525     wrt=wrt,
    526     conditional=conditional,
    527     average_by=average_by,
    528     eps=eps,
    529     slope=slope,
    530     use_hdi=use_hdi,
    531     prob=prob,
    532     transforms=transforms,
    533     sample_new_groups=sample_new_groups,
    534 )
    536 return _plot_differences(
    537     model=model,
    538     conditional_info=conditional_info,
   (...)
    545     subplot_kwargs=subplot_kwargs,
    546 )

File ~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/effects.py:844, in slopes(model, idata, wrt, conditional, average_by, eps, slope, use_hdi, prob, transforms, sample_new_groups)
    842     effect_type = "comparisons"
    843     eps = None
--> 844 wrt_info = VariableInfo(model, wrt, effect_type, grid, eps)
    846 lower_bound = round((1 - prob) / 2, 4)
    847 upper_bound = 1 - lower_bound

File <string>:9, in __init__(self, model, variable, kind, grid, eps, user_passed)

File ~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/utils.py:79, in VariableInfo.__post_init__(self)
     77     if isinstance(self.variable, list):
     78         self.name = " ".join(self.variable)
---> 79     self.values = self.set_default_variable_values()
     80 elif not isinstance(self.variable, (list, dict, str)):
     81     raise TypeError("`variable` must be a list, dict, or string")

File ~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/utils.py:130, in VariableInfo.set_default_variable_values(self)
    127                     elif component.kind == "categoric":
    128                         values = np.unique(predictor_data)
--> 130 return values

UnboundLocalError: cannot access local variable 'values' where it is not associated with a value
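
For anyone unfamiliar with this error: Python raises UnboundLocalError when a function reaches `return values` without any branch having assigned `values`. The traceback suggests that in `set_default_variable_values` the name is only assigned inside conditional branches. Below is a minimal sketch of that pattern. `default_values` and its arguments are hypothetical and are not Bambi's actual code.

```python
def default_values(variable, predictors):
    """Hypothetical stand-in: look up default values for `variable`."""
    for name, kind, data in predictors:
        if name != variable:
            continue
        if kind == "numeric":
            values = [sum(data) / len(data)]  # e.g. the mean
        elif kind == "categoric":
            values = sorted(set(data))
    # If `variable` matched no predictor, `values` was never assigned.
    return values


predictors = [("Edu", "numeric", [0.6, 0.8, 1.0])]

print(default_values("Edu", predictors))

try:
    default_values("Overall", predictors)  # the response, not a predictor
except UnboundLocalError as err:
    print(type(err).__name__)
```

Under this reading, asking for a name that matches no predictor (such as the response) falls through every branch, which would be consistent with the error above.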

NathanielF commented 7 months ago

I imagine there is quite a bit of work in adapting the interpret functions to account for the static spline bases, but I'm just wondering if there is something obvious I'm missing?

GStechschulte commented 7 months ago

Hey @NathanielF thanks for raising the issue. This should work. What version of Bambi are you using?

A couple of months ago we improved how the interpret package parses hsgp and basis spline terms in a Bambi model. You may need to install from source to get these latest changes.

Nonetheless, I will attempt to reproduce your example and get back to you.

NathanielF commented 7 months ago

Thanks @GStechschulte, I'm using 0.13:

image

I'd be keen to discover whether I can use Bambi's native interpret functions.

For the moment I've hacked together a routine to extract the components:


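import arviz as az
import matplotlib.pyplot as plt

# Basis matrix and posterior coefficients for each spline term
Bincome = spline_model.response_component.design.common['bs(Income, degree=3, knots=knots_income)']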

income_coefs = az.extract(spline_idata['posterior']['bs(Income, degree=3, knots=knots_income)'])['bs(Income, degree=3, knots=knots_income)']

Bedu = spline_model.response_component.design.common['bs(Edu, degree=3, knots=knots_edu)']

edu_coefs = az.extract(spline_idata['posterior']['bs(Edu, degree=3, knots=knots_edu)'])['bs(Edu, degree=3, knots=knots_edu)']

Bhealth = spline_model.response_component.design.common['bs(Health, degree=3, knots=knots_health)']

health_coefs = az.extract(spline_idata['posterior']['bs(Health, degree=3, knots=knots_health)'])['bs(Health, degree=3, knots=knots_health)']

income = np.dot(Bincome, income_coefs).T 
edu = np.dot(Bedu, edu_coefs).T
health = np.dot(Bhealth, health_coefs).T

intercept = az.extract(spline_idata['posterior']['Intercept'])['Intercept'].values

fig, ax = plt.subplots(figsize=(9, 7))
for i in range(100):
    if i == 1:
        ax.plot(income[i], label='Income Component', color='red')
        ax.plot(edu[i], label='Edu Component', color='blue')
        ax.plot(health[i], label='Health Component', color='darkgreen')
        ax.plot(intercept[i] + income[i] + edu[i] + health[i], label='Combined Components', color='purple')
    else: 
        ax.plot(income[i], alpha=0.1, color='red')
        ax.plot(edu[i], alpha=0.1, color='blue')
        ax.plot(health[i], alpha=0.1, color='darkgreen')
        ax.plot(intercept[i] + income[i] + edu[i] + health[i], color='purple', alpha=0.3)

ax.scatter(range(len(spline_idata['observed_data']['Overall'])), spline_idata['observed_data']['Overall'], label='Observed', color='grey', s=56, ec='black')
ax.set_title("Additive Spline Components", fontsize=20)
ax.legend();
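labels = pisa_df.dropna(axis=0).reset_index()['Country']
ax.set_xticks(range(len(labels)))  # fix tick positions before assigning labels
ax.set_xticklabels(labels, rotation=90)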

which gives me this plot:

image

Work in Progress here: https://nathanielf.github.io/posts/post-with-code/GAMs_and_GPs/gams_and_gps.html

Diving into splines, hierarchical spline modelling, and HSGP approximations using Bambi. It would be great if I could "show off" the full functionality.
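
The extraction routine above boils down to one matrix product per term: each posterior draw's spline component is the basis matrix times that draw's coefficient vector. Here is a toy sketch with invented numbers. The 3x2 basis and the two draws are made up for illustration, and plain lists stand in for the NumPy and xarray objects used in the real code.

```python
def component_per_draw(basis, coefs):
    """Return one fitted curve per posterior draw.

    basis: n_obs rows, each holding k basis-function values.
    coefs: n_draws rows, each holding k coefficients.
    """
    return [
        [sum(b * c for b, c in zip(row, draw)) for row in basis]
        for draw in coefs
    ]


# Invented numbers: 3 observations, 2 basis functions, 2 draws.
basis = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
coefs = [[2.0, 4.0], [1.0, 3.0]]

curves = component_per_draw(basis, coefs)
print(curves)  # [[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]]
```

Adding the intercept draw to the sum of the three components gives the combined curve plotted in purple.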

tomicapretto commented 7 months ago

@NathanielF @GStechschulte I'm working on this. I think the issue is that we're looking up some variable in the data frame when we shouldn't be.

tomicapretto commented 7 months ago

@GStechschulte I see in the code @NathanielF shared he's setting wrt="Overall", when "Overall" is the response variable. I think that's not the intended usage of the argument, right?

@NathanielF are you trying to use plot_slopes() or plot_predictions()? Just note that I'm not an expert on plot_slopes() (Gabriel knows much more about this than me), but I'm guessing what you want is plot_predictions().
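
To make the distinction concrete: a slope is the change in the predicted response for a small change in one predictor, so `wrt` has to name a predictor rather than the response. Below is a rough sketch of a forward finite difference, assuming that is roughly what the `eps` argument seen in the traceback controls. `predict_mean` is a made-up stand-in for a fitted model, not a Bambi function.

```python
def predict_mean(edu, income=0.5):
    """Made-up mean function standing in for a fitted model."""
    return 400.0 + 100.0 * edu**2 + 50.0 * income


def slope_wrt_edu(edu, eps=1e-4):
    """Forward difference: (f(x + eps) - f(x)) / eps."""
    return (predict_mean(edu + eps) - predict_mean(edu)) / eps


for edu in (0.6, 0.8, 1.0):
    print(edu, round(slope_wrt_edu(edu), 2))
```

plot_predictions(), by contrast, would correspond to evaluating `predict_mean` itself over a grid of `edu` values.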

NathanielF commented 7 months ago

Ah ok. That may be it. I will have another look!

tomicapretto commented 7 months ago

It's still broken in the latest release and the main branch. I'm about to push a branch with a quick fix and open an issue explaining the problem. Thanks @NathanielF for raising this, it allowed us to find a bug :)

NathanielF commented 7 months ago

Nice! Happy to help

tomicapretto commented 7 months ago

Is this something close to what you're looking for?

fig, ax = bmb.interpret.plot_predictions(
    spline_model, spline_idata, "Edu"
)

image

fig, ax = bmb.interpret.plot_predictions(
    spline_model, spline_idata, {"Edu": np.linspace(0.7, 1, num=100)}
)

image

tomicapretto commented 7 months ago

@NathanielF this is the branch https://github.com/tomicapretto/bambi/tree/spline_interpret and this is a notebook with the code https://github.com/tomicapretto/bambi/blob/spline_interpret/spline_prediction.ipynb (I'll delete it before merging into Bambi's main)

The main point is that in conditional we pass the names of variables, not the names of terms. That is because a variable can play a role in many terms, for example in 1 + x + x:z. The same happens when you use a transformation such as log(x), exp(x) or bs(x). We only need to pass the name of the covariate, not the name of the term.
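
A small sketch of that distinction follows. `variables_in_term` is a hypothetical helper, not part of Bambi or formulae, and its regex only copes with simple terms like the ones in this thread. It shows several terms mapping back to the same underlying variable, which is why conditional is keyed by variable.

```python
import re


def variables_in_term(term, columns):
    """Return the data-frame columns that appear in a formula term."""
    tokens = re.findall(r"[A-Za-z_]\w*", term)
    return [tok for tok in tokens if tok in columns]


columns = {"x", "z", "Edu"}
terms = ["x", "x:z", "log(x)", "bs(Edu, degree=3, knots=knots_edu)"]

for term in terms:
    print(f"{term!r} -> {variables_in_term(term, columns)}")
```

So the key to pass for the spline term is "Edu", not the full "bs(Edu, degree=3, knots=knots_edu)" string.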

NathanielF commented 7 months ago

The plot_predictions code works for me too. But just out of curiosity I tried plot_slopes without the Overall variable.

fig, ax = bmb.interpret.plot_slopes(
    spline_model,
    spline_idata,
    wrt="Health",
    conditional={
        "Income": np.linspace(0.1, 0.5, 10),
    }
)

I've changed these to two covariates and I'm now getting a different error:

{
    "name": "TypeError",
    "message": "unhashable type: 'numpy.ndarray'",
    "stack": "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)\n\u001b[1;32m/Users/nathanielforde/Documents/Github/NathanielF.github.io/posts/post-with-code/GAMs_and_GPs/gams_and_gps.qmd\u001b[0m in \u001b[0;36mline 1\n\u001b[0;32m----> <a href='file:///Users/nathanielforde/Documents/Github/NathanielF.github.io/posts/post-with-code/GAMs_and_GPs/gams_and_gps.qmd?line=962'>963</a>\u001b[0m fig, ax \u001b[39m=\u001b[39m bmb\u001b[39m.\u001b[39;49minterpret\u001b[39m.\u001b[39;49mplot_slopes(\n\u001b[1;32m      <a href='file:///Users/nathanielforde/Documents/Github/NathanielF.github.io/posts/post-with-code/GAMs_and_GPs/gams_and_gps.qmd?line=963'>964</a>\u001b[0m     spline_model,\n\u001b[1;32m      <a href='file:///Users/nathanielforde/Documents/Github/NathanielF.github.io/posts/post-with-code/GAMs_and_GPs/gams_and_gps.qmd?line=964'>965</a>\u001b[0m     spline_idata,\n\u001b[1;32m      <a href='file:///Users/nathanielforde/Documents/Github/NathanielF.github.io/posts/post-with-code/GAMs_and_GPs/gams_and_gps.qmd?line=965'>966</a>\u001b[0m     wrt\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mHealth\u001b[39;49m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m      <a href='file:///Users/nathanielforde/Documents/Github/NathanielF.github.io/posts/post-with-code/GAMs_and_GPs/gams_and_gps.qmd?line=966'>967</a>\u001b[0m     conditional\u001b[39m=\u001b[39;49m{\n\u001b[1;32m      <a href='file:///Users/nathanielforde/Documents/Github/NathanielF.github.io/posts/post-with-code/GAMs_and_GPs/gams_and_gps.qmd?line=967'>968</a>\u001b[0m         \u001b[39m\"\u001b[39;49m\u001b[39mIncome\u001b[39;49m\u001b[39m\"\u001b[39;49m: np\u001b[39m.\u001b[39;49mlinspace(\u001b[39m0.1\u001b[39;49m, \u001b[39m0.5\u001b[39;49m, \u001b[39m10\u001b[39;49m),\n\u001b[1;32m      <a 
href='file:///Users/nathanielforde/Documents/Github/NathanielF.github.io/posts/post-with-code/GAMs_and_GPs/gams_and_gps.qmd?line=968'>969</a>\u001b[0m     }\n\u001b[1;32m      <a href='file:///Users/nathanielforde/Documents/Github/NathanielF.github.io/posts/post-with-code/GAMs_and_GPs/gams_and_gps.qmd?line=969'>970</a>\u001b[0m )\n\nFile \u001b[0;32m~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py:522\u001b[0m, in \u001b[0;36mplot_slopes\u001b[0;34m(model, idata, wrt, conditional, average_by, eps, slope, sample_new_groups, use_hdi, prob, transforms, legend, ax, fig_kwargs, subplot_kwargs)\u001b[0m\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=517'>518</a>\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m\u001b[39mslope\u001b[39m\u001b[39m'\u001b[39m\u001b[39m must be one of (\u001b[39m\u001b[39m'\u001b[39m\u001b[39mdydx\u001b[39m\u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39mdyex\u001b[39m\u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39meyex\u001b[39m\u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39meydx\u001b[39m\u001b[39m'\u001b[39m\u001b[39m)\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=519'>520</a>\u001b[0m conditional_info \u001b[39m=\u001b[39m ConditionalInfo(model, conditional)\n\u001b[0;32m--> <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=521'>522</a>\u001b[0m slopes_summary \u001b[39m=\u001b[39m slopes(\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=522'>523</a>\u001b[0m  
   model\u001b[39m=\u001b[39;49mmodel,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=523'>524</a>\u001b[0m     idata\u001b[39m=\u001b[39;49midata,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=524'>525</a>\u001b[0m     wrt\u001b[39m=\u001b[39;49mwrt,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=525'>526</a>\u001b[0m     conditional\u001b[39m=\u001b[39;49mconditional,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=526'>527</a>\u001b[0m     average_by\u001b[39m=\u001b[39;49maverage_by,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=527'>528</a>\u001b[0m     eps\u001b[39m=\u001b[39;49meps,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=528'>529</a>\u001b[0m     slope\u001b[39m=\u001b[39;49mslope,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=529'>530</a>\u001b[0m     use_hdi\u001b[39m=\u001b[39;49muse_hdi,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=530'>531</a>\u001b[0m     prob\u001b[39m=\u001b[39;49mprob,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=531'>532</a>\u001b[0m     transforms\u001b[39m=\u001b[39;49mtransforms,\n\u001b[1;32m    <a 
href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=532'>533</a>\u001b[0m     sample_new_groups\u001b[39m=\u001b[39;49msample_new_groups,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=533'>534</a>\u001b[0m )\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=535'>536</a>\u001b[0m \u001b[39mreturn\u001b[39;00m _plot_differences(\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=536'>537</a>\u001b[0m     model\u001b[39m=\u001b[39mmodel,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=537'>538</a>\u001b[0m     conditional_info\u001b[39m=\u001b[39mconditional_info,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=544'>545</a>\u001b[0m     subplot_kwargs\u001b[39m=\u001b[39msubplot_kwargs,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/plotting.py?line=545'>546</a>\u001b[0m )\n\nFile \u001b[0;32m~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/effects.py:855\u001b[0m, in \u001b[0;36mslopes\u001b[0;34m(model, idata, wrt, conditional, average_by, eps, slope, use_hdi, prob, transforms, sample_new_groups)\u001b[0m\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/effects.py?line=851'>852</a>\u001b[0m response \u001b[39m=\u001b[39m ResponseInfo(response_name, 
\u001b[39m\"\u001b[39m\u001b[39mmean\u001b[39m\u001b[39m\"\u001b[39m, lower_bound, upper_bound)\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/effects.py?line=852'>853</a>\u001b[0m response_transform \u001b[39m=\u001b[39m transforms\u001b[39m.\u001b[39mget(response_name, identity)\n\u001b[0;32m--> <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/effects.py?line=854'>855</a>\u001b[0m slopes_data \u001b[39m=\u001b[39m create_differences_data(\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/effects.py?line=855'>856</a>\u001b[0m     conditional_info, wrt_info, conditional_info\u001b[39m.\u001b[39;49muser_passed, effect_type\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/effects.py?line=856'>857</a>\u001b[0m )\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/effects.py?line=857'>858</a>\u001b[0m idata \u001b[39m=\u001b[39m model\u001b[39m.\u001b[39mpredict(\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/effects.py?line=858'>859</a>\u001b[0m     idata, data\u001b[39m=\u001b[39mslopes_data, sample_new_groups\u001b[39m=\u001b[39msample_new_groups, inplace\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/effects.py?line=859'>860</a>\u001b[0m )\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/effects.py?line=861'>862</a>\u001b[0m \u001b[39m# returns empty array if model predictions do not have multiple 
dimensions\u001b[39;00m\n\nFile \u001b[0;32m~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/create_data.py:198\u001b[0m, in \u001b[0;36mcreate_differences_data\u001b[0;34m(condition_info, variable_info, user_passed, kind)\u001b[0m\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/create_data.py?line=194'>195</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m condition_info\u001b[39m.\u001b[39mcovariates:\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/create_data.py?line=195'>196</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m _differences_unit_level(variable_info, kind)\n\u001b[0;32m--> <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/create_data.py?line=197'>198</a>\u001b[0m \u001b[39mreturn\u001b[39;00m _grid_level(condition_info, variable_info, user_passed, kind)\n\nFile \u001b[0;32m~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/create_data.py:124\u001b[0m, in \u001b[0;36m_grid_level\u001b[0;34m(condition_info, variable_info, user_passed, kind)\u001b[0m\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/create_data.py?line=118'>119</a>\u001b[0m data_grid \u001b[39m=\u001b[39m enforce_dtypes(condition_info\u001b[39m.\u001b[39mmodel\u001b[39m.\u001b[39mdata, data_grid, except_col)\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/create_data.py?line=120'>121</a>\u001b[0m \u001b[39m# After computing default values, fractional values may have been computed.\u001b[39;00m\n\u001b[1;32m    <a 
href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/create_data.py?line=121'>122</a>\u001b[0m \u001b[39m# Enforcing the dtype of \"int\" may create duplicate rows as it will round\u001b[39;00m\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/create_data.py?line=122'>123</a>\u001b[0m \u001b[39m# the fractional values.\u001b[39;00m\n\u001b[0;32m--> <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/create_data.py?line=123'>124</a>\u001b[0m data_grid \u001b[39m=\u001b[39m data_grid\u001b[39m.\u001b[39;49mdrop_duplicates()\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/bambi/interpret/create_data.py?line=125'>126</a>\u001b[0m \u001b[39mreturn\u001b[39;00m data_grid\u001b[39m.\u001b[39mreset_index(drop\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\nFile \u001b[0;32m~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py:6802\u001b[0m, in \u001b[0;36mDataFrame.drop_duplicates\u001b[0;34m(self, subset, keep, inplace, ignore_index)\u001b[0m\n\u001b[1;32m   <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py?line=6798'>6799</a>\u001b[0m inplace \u001b[39m=\u001b[39m validate_bool_kwarg(inplace, \u001b[39m\"\u001b[39m\u001b[39minplace\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m   <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py?line=6799'>6800</a>\u001b[0m ignore_index \u001b[39m=\u001b[39m validate_bool_kwarg(ignore_index, \u001b[39m\"\u001b[39m\u001b[39mignore_index\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m-> <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py?line=6801'>6802</a>\u001b[0m 
result \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m[\u001b[39m-\u001b[39m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mduplicated(subset, keep\u001b[39m=\u001b[39;49mkeep)]\n\u001b[1;32m   <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py?line=6802'>6803</a>\u001b[0m \u001b[39mif\u001b[39;00m ignore_index:\n\u001b[1;32m   <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py?line=6803'>6804</a>\u001b[0m     result\u001b[39m.\u001b[39mindex \u001b[39m=\u001b[39m default_index(\u001b[39mlen\u001b[39m(result))\n\nFile \u001b[0;32m~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py:6942\u001b[0m, in \u001b[0;36mDataFrame.duplicated\u001b[0;34m(self, subset, keep)\u001b[0m\n\u001b[1;32m   <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py?line=6939'>6940</a>\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m   <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py?line=6940'>6941</a>\u001b[0m     vals \u001b[39m=\u001b[39m (col\u001b[39m.\u001b[39mvalues \u001b[39mfor\u001b[39;00m name, col \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mitems() \u001b[39mif\u001b[39;00m name \u001b[39min\u001b[39;00m subset)\n\u001b[0;32m-> <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py?line=6941'>6942</a>\u001b[0m     labels, shape \u001b[39m=\u001b[39m \u001b[39mmap\u001b[39m(\u001b[39mlist\u001b[39m, \u001b[39mzip\u001b[39;49m(\u001b[39m*\u001b[39;49m\u001b[39mmap\u001b[39;49m(f, vals)))\n\u001b[1;32m   <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py?line=6943'>6944</a>\u001b[0m     ids \u001b[39m=\u001b[39m get_group_index(labels, 
\u001b[39mtuple\u001b[39m(shape), sort\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m, xnull\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m)\n\u001b[1;32m   <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py?line=6944'>6945</a>\u001b[0m     result \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_constructor_sliced(duplicated(ids, keep), index\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mindex)\n\nFile \u001b[0;32m~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py:6910\u001b[0m, in \u001b[0;36mDataFrame.duplicated.<locals>.f\u001b[0;34m(vals)\u001b[0m\n\u001b[1;32m   <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py?line=6908'>6909</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mf\u001b[39m(vals) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mtuple\u001b[39m[np\u001b[39m.\u001b[39mndarray, \u001b[39mint\u001b[39m]:\n\u001b[0;32m-> <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py?line=6909'>6910</a>\u001b[0m     labels, shape \u001b[39m=\u001b[39m algorithms\u001b[39m.\u001b[39;49mfactorize(vals, size_hint\u001b[39m=\u001b[39;49m\u001b[39mlen\u001b[39;49m(\u001b[39mself\u001b[39;49m))\n\u001b[1;32m   <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/frame.py?line=6910'>6911</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m labels\u001b[39m.\u001b[39mastype(\u001b[39m\"\u001b[39m\u001b[39mi8\u001b[39m\u001b[39m\"\u001b[39m, copy\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m), \u001b[39mlen\u001b[39m(shape)\n\nFile \u001b[0;32m~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py:795\u001b[0m, in \u001b[0;36mfactorize\u001b[0;34m(values, sort, use_na_sentinel, size_hint)\u001b[0m\n\u001b[1;32m    <a 
href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=791'>792</a>\u001b[0m             \u001b[39m# Don't modify (potentially user-provided) array\u001b[39;00m\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=792'>793</a>\u001b[0m             values \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39mwhere(null_mask, na_value, values)\n\u001b[0;32m--> <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=794'>795</a>\u001b[0m     codes, uniques \u001b[39m=\u001b[39m factorize_array(\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=795'>796</a>\u001b[0m         values,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=796'>797</a>\u001b[0m         use_na_sentinel\u001b[39m=\u001b[39;49muse_na_sentinel,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=797'>798</a>\u001b[0m         size_hint\u001b[39m=\u001b[39;49msize_hint,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=798'>799</a>\u001b[0m     )\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=800'>801</a>\u001b[0m \u001b[39mif\u001b[39;00m sort \u001b[39mand\u001b[39;00m \u001b[39mlen\u001b[39m(uniques) \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=801'>802</a>\u001b[0m     uniques, codes 
\u001b[39m=\u001b[39m safe_sort(\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=802'>803</a>\u001b[0m         uniques,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=803'>804</a>\u001b[0m         codes,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=806'>807</a>\u001b[0m         verify\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=807'>808</a>\u001b[0m     )\n\nFile \u001b[0;32m~/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py:595\u001b[0m, in \u001b[0;36mfactorize_array\u001b[0;34m(values, use_na_sentinel, size_hint, na_value, mask)\u001b[0m\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=591'>592</a>\u001b[0m hash_klass, values \u001b[39m=\u001b[39m _get_hashtable_algo(values)\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=593'>594</a>\u001b[0m table \u001b[39m=\u001b[39m hash_klass(size_hint \u001b[39mor\u001b[39;00m \u001b[39mlen\u001b[39m(values))\n\u001b[0;32m--> <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=594'>595</a>\u001b[0m uniques, codes \u001b[39m=\u001b[39m table\u001b[39m.\u001b[39;49mfactorize(\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=595'>596</a>\u001b[0m     values,\n\u001b[1;32m    <a 
href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=596'>597</a>\u001b[0m     na_sentinel\u001b[39m=\u001b[39;49m\u001b[39m-\u001b[39;49m\u001b[39m1\u001b[39;49m,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=597'>598</a>\u001b[0m     na_value\u001b[39m=\u001b[39;49mna_value,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=598'>599</a>\u001b[0m     mask\u001b[39m=\u001b[39;49mmask,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=599'>600</a>\u001b[0m     ignore_na\u001b[39m=\u001b[39;49muse_na_sentinel,\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=600'>601</a>\u001b[0m )\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=602'>603</a>\u001b[0m \u001b[39m# re-cast e.g. i8->dt64/td64, uint8->bool\u001b[39;00m\n\u001b[1;32m    <a href='file:///Users/nathanielforde/mambaforge/envs/pymc_causal/lib/python3.11/site-packages/pandas/core/algorithms.py?line=603'>604</a>\u001b[0m uniques \u001b[39m=\u001b[39m _reconstruct_data(uniques, original\u001b[39m.\u001b[39mdtype, original)\n\nFile \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:7281\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.factorize\u001b[0;34m()\u001b[0m\n\nFile \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:7195\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable._unique\u001b[0;34m()\u001b[0m\n\n\u001b[0;31mTypeError\u001b[0m: unhashable type: 'numpy.ndarray'"
}
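The final TypeError in the traceback can be reproduced outside Bambi. This is a minimal sketch, assuming the data grid ended up with a whole array stored in a single cell; the thread does not establish how that happened, and the column values below are made up for illustration.

```python
import numpy as np
import pandas as pd

# Assumed situation: each cell of "Edu" holds an entire array instead of a scalar,
# so pandas stores the column with object dtype.
grid = pd.DataFrame(
    {
        "Income": [0.5, 0.5],
        "Edu": [np.linspace(0.6, 1, 3), np.linspace(0.6, 1, 3)],
    }
)

try:
    # drop_duplicates hashes every cell; numpy arrays are not hashable
    grid.drop_duplicates()
    message = None
except TypeError as err:
    message = str(err)

print(message)
```

If the grid held one scalar per cell, drop_duplicates would have nothing unhashable to choke on.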

Update

It works if we convert object to a list


fig, ax = bmb.interpret.plot_slopes(
    spline_model,
    spline_idata,
    wrt="Edu",
    conditional={
        "Income": list(np.linspace(0.1, 1, 10))
    }
)

But it gives a somewhat nonsensical plot

image

NathanielF commented 7 months ago

OK, starting to make progress. I think I just wasn't using the API correctly.

If I pass in a dictionary for the wrt variable in the slopes calculation, then I start to see variation across the Income variable:


bmb.interpret.plot_slopes(
    spline_model,
    spline_idata,
    wrt={"Edu": [.8, .9]},
    conditional={
        "Income": list(np.linspace(0.1, 1, 10)),
        "Health": list(np.linspace(0.1, 1, 10))
    }
)

image

I guess I'm just not entirely sure what plot_slopes will do with a spline-based model?

GStechschulte commented 7 months ago

Thanks a lot @tomicapretto and @NathanielF for your comments.

@GStechschulte I see in the code @NathanielF shared he's setting wrt="Overall", when "Overall" is the response variable. I think that's not the intended usage of the argument, right?

That's correct. Generally speaking, all of the interpret functions are about how the response variable changes given some values of the covariates, so wrt should name a covariate rather than the response.

@NathanielF are you trying to use plot_slopes() or plot_predictions()? Just note that I'm not an expert on plot_slopes() (Gabriel knows much more about this than me). But I'm guessing what you want is plot_predictions().

Yeah, it seems plot_predictions is what you want. Though I need to read your blog post first to confirm. This isn't to say you can't use plot_slopes. Indeed, slopes should work.

I guess I'm just not entirely sure what plot_slopes will do with a spline-based model?

plot_slopes isn't doing anything per se with the model. What it is doing is saying, "Okay, cool. You have defined and fit a model on some data. Now, let's generate some 'new data', feed that data to the model, and see what the predictions are". I think Richard McElreath uses the term "seeing through the eyes of the model".

Under the hood, slopes assembles a pairwise grid of values that lets you compare predictions of the response variable before and after a small change in the wrt predictor, conditional on the other covariates specified in the model. We use finite differences to approximate the derivative.
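The finite-difference idea can be sketched in a few lines. This is not Bambi's implementation: predict is a hypothetical stand-in for the fitted model's mean prediction, and the step size eps is an arbitrary choice made for this illustration.

```python
import numpy as np


def predict(edu, income):
    # Hypothetical stand-in for the model's mean prediction (made-up coefficients)
    return 400.0 + 120.0 * edu**2 + 30.0 * income


eps = 1e-4                              # small change applied to the wrt variable
edu_value = 0.8                         # value of the wrt variable
income_grid = np.linspace(0.1, 1, 10)   # conditional covariate values

# Pairwise grid: every conditional value is evaluated at wrt and at wrt + eps
low = predict(edu_value, income_grid)
high = predict(edu_value + eps, income_grid)

# Forward finite difference approximates d(prediction)/d(Edu) at each grid point
slopes = (high - low) / eps
print(slopes)
```

In this toy function the slope is the same at every Income value because Edu and Income do not interact; with a fitted model the slope can vary across the conditional grid.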

Regarding your two comments

It works if we convert object to a list

Not sure why this is required in your example. When I run our documentation examples for slopes where a np.ndarray is passed as the value, it works.

If I pass in a dictionary for the wrt variable in the slopes calculation, then I start to see variation across the Income variable

Same as above. You shouldn't have to pass a dict.

I am looking into the last two comments. :)

GStechschulte commented 7 months ago

@NathanielF regarding your comment:

It works if we convert object to a list

I don't have this problem. After cherry-picking @tomicapretto's hotfix commit, I can run:

fig, ax = bmb.interpret.plot_slopes(
    spline_model,
    spline_idata,
    wrt={"Edu": [.8, .9]},
    conditional={
        "Income": np.linspace(0.1, 1, 10), # No need to convert array to list
        "Health": np.linspace(0.1, 1, 3)     # No need to convert array to list
    }
)

image

I also tested a few other data types for the conditional values (array, list, and float):

import matplotlib.pyplot as plt

fig, ax = bmb.interpret.plot_predictions(
    spline_model,
    spline_idata,
    {
        "Edu": np.linspace(0.7, 1, num=100),  # array
        "Income": [0.1, 0.3, 0.5, 0.7, 0.9],  # list
        "Health": 0.893,                      # float
    },
    subplot_kwargs={"main": "Edu", "group": "Income", "panel": "Income"},
    fig_kwargs={"figsize": (12, 4)},
    legend=False,
)
plt.tight_layout()

image

and everything is working as expected. It is interesting to see how much the predictions for Overall change once Income is $> 0.5$.

NathanielF commented 7 months ago

Very good. Not sure about the array vs. list issue, but I was also able to get plot_predictions working to show similar effects of Income. The response just blows up at extreme realisations, which threw me.

Thanks for your work on this!

Unrelated, but I wondered if you had thoughts on the hierarchical spline model I've implemented in that blog post? Or how something similar could be handled in bambi? Especially the predicting on new years/groups step.

GStechschulte commented 7 months ago

@NathanielF anytime! I hope to read your blog in depth over the coming days 👍🏼

tomicapretto commented 7 months ago

Unrelated, but I wondered if you had thoughts on the hierarchical spline model I've implemented in that blog post? Or how something similar could be handled in bambi? Especially the predicting on new years/groups step.

Haven't looked into it in depth, but I'm very curious about it, so I will be reading it and I'll provide feedback here if you want.

NathanielF commented 7 months ago

That would be amazing @tomicapretto. I think all the mechanics are there in the blog post. I need to finesse the writing a bit.

But FYI, the basic theme I want to interrogate is that unpooled and pooled spline models seem brittle for extrapolation, whereas a hierarchical model seems to offer a way to "abstract" over the individual loss curves and pull out something like their generic shape to predict on new years.

I'm thinking this is a good way to balance the tendency of spline models to overfit?
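A rough sketch of the "predict on a new group" step, under loud assumptions: the posterior draws below are simulated stand-ins rather than output from the blog post's model, and a polynomial basis stands in for the spline basis. The only point is that an unseen group's coefficients are drawn from the population-level distribution instead of being reused from any single observed group.

```python
import numpy as np

rng = np.random.default_rng(0)
n_draws, n_basis = 500, 4

# Made-up "posterior draws" of the population mean and spread of each basis coefficient
mu = rng.normal(loc=[0.2, 0.6, 0.9, 1.0], scale=0.05, size=(n_draws, n_basis))
sigma = np.abs(rng.normal(loc=0.1, scale=0.02, size=(n_draws, n_basis)))

# New group: one coefficient vector per posterior draw, sampled around the population mean
beta_new = rng.normal(mu, sigma)

# Stand-in basis matrix (polynomial here; a B-spline basis would take its place)
x = np.linspace(0, 1, 20)
basis = np.vander(x, n_basis, increasing=True)  # shape (20, n_basis)

# One predicted curve per draw, then a mean curve and a 94% band across draws
curves = beta_new @ basis.T                     # shape (n_draws, 20)
mean_curve = curves.mean(axis=0)
lower, upper = np.percentile(curves, [3, 97], axis=0)
```

The width of the band reflects both uncertainty in the population parameters and group-to-group variation, which is what keeps predictions for a new group from inheriting the quirks of one fitted curve.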

NathanielF commented 7 months ago

Just FYI @tomicapretto and @GStechschulte, I've tightened the writing here and shown how to do sampling for "new groups" over my hierarchical spline model for the insurance curves. Additionally, I've shown how adding hierarchies to a multiple-regression spline model can improve the extrapolation to new groups.

image

tomicapretto commented 7 months ago

@NathanielF thanks for the update! I'm reading through it right now (it's been on my mind, but I couldn't find the time).

NathanielF commented 7 months ago

No worries. Appreciate any feedback you might have.