vincentarelbundock / pymarginaleffects

GNU General Public License v3.0

`get_variables_names()` in class `ModelStatsmodels` does not return all variables, which causes errors #91

Closed: RoelVerbelen closed this issue 4 months ago

RoelVerbelen commented 4 months ago

As far as I'm aware, there's no easy way to extract the names of the original columns used in a patsy formula (see these open tickets here and here), so for now you have to rely on regular expressions.

However, the current code does not handle all the complex constructs that can occur in formulas, which leads to errors in marginaleffects.
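To make the failure concrete, here is a minimal, self-contained sketch of the "strip the brackets, then match exact column names" logic used by the current code. The `EXOG_NAMES` list is hypothetical: it is written in the style patsy uses for design-matrix columns, but the exact names and their number for this model may differ. The function name `current_approach` is also illustrative, and a plain `set` stands in for the `polars` deduplication.

```python
import re

# Hypothetical design-matrix column names, in patsy's style, for the formula
# "price ~ depth:color + C(cut, Treatment('Good')) + cr(np.minimum(carat, 0.8), ...)".
EXOG_NAMES = [
    "Intercept",
    "C(cut, Treatment('Good'))[T.Fair]",
    "C(cut, Treatment('Good'))[T.Ideal]",
    "depth:color[D]",
    "depth:color[E]",
    "cr(np.minimum(carat, 0.8), df=5, constraints='center')[0]",
]
DATA_COLUMNS = ["carat", "cut", "color", "clarity", "depth", "price"]


def current_approach(exog_names, columns):
    """Mimic the exog_names-based lookup: strip level labels, keep exact column matches."""
    # Remove bracketed suffixes such as "[T.Fair]" or "[0]".
    stripped = [re.sub(r"\[.*\]", "", name) for name in exog_names]
    # Only names that are *exactly* a data column survive; transformed or
    # interacted terms like "depth:color" never do.
    return sorted({name for name in stripped if name in columns})


print(current_approach(EXOG_NAMES, DATA_COLUMNS))  # []
```

After stripping, every remaining name still wraps the column in a function call or an interaction, so the exact-match filter discards all of them, which is consistent with the empty list reported below.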

I try to illustrate this in the code below and suggest a potential alternative (which I'm currently relying on): detect whether any of the data columns, surrounded by word boundaries, occurs in the model formula. It's still not perfect, as it can capture non-model terms (such as Treatment, Good, minimum, df, constraints, or center in the example below, if these exist as columns in the data), but at least it won't miss any of the predictors.
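A minimal standalone sketch of the word-boundary check, using only the standard library. The formula is the one from this issue; the extra column names `Good`, `df`, and `center` are hypothetical, added only to show the false positives described above. The function name `word_boundary_variables` is illustrative, and iterating over `columns` (instead of building a set) is a choice made here so the output keeps column order.

```python
import re

FORMULA = (
    "price ~ depth:color + C(cut, Treatment('Good')) "
    "+ cr(np.minimum(carat, 0.8), df=5, constraints='center')"
)


def word_boundary_variables(formula, columns):
    """Return the data columns that occur in `formula` as whole words."""
    # re.escape guards against regex metacharacters in column names;
    # \b requires that the match is not part of a longer identifier.
    return [c for c in columns if re.search(rf"\b{re.escape(c)}\b", formula)]


# The first six mirror columns of `diamonds`; the last three are hypothetical
# names that only collide with words appearing inside the formula.
columns = ["carat", "cut", "color", "clarity", "depth", "price", "Good", "df", "center"]
print(word_boundary_variables(FORMULA, columns))
# ['carat', 'cut', 'color', 'depth', 'price', 'Good', 'df', 'center']
```

None of the real predictors is missed, but `Good`, `df`, and `center` are picked up as soon as the data contains such columns. Note that `\b` only behaves as intended when a column name starts and ends with a word character.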

```python
import re

import numpy as np  # needed: patsy evaluates np.minimum in the caller's namespace
import pandas as pd
import polars as pl
import statsmodels.formula.api as smf
from marginaleffects import predictions
from marginaleffects.sanitize_model import sanitize_model

diamonds = pd.read_csv("https://raw.githubusercontent.com/vincentarelbundock/Rdatasets/master/csv/ggplot2/diamonds.csv")

# Complex formula with an interaction-only term, a categorical with a custom reference level, and a spline
model = smf.ols(
    "price ~ depth:color + C(cut, Treatment('Good')) + cr(np.minimum(carat, 0.8), df=5, constraints='center')",
    data=diamonds,
).fit()

# Fails: ValueError: There is no valid column name in `variables`.
predictions(model, newdata=diamonds, by="cut")

# Create ModelStatsmodels object
self = sanitize_model(model)

# Variable list shows up empty
self.get_variables_names()

# Current code: Lines 53-56 in model_statsmodels.py
variables = self.model.model.exog_names
variables = [re.sub(r"\[.*\]", "", x) for x in variables]  # strip level labels like "[T.Fair]"
variables = [x for x in variables if x in self.modeldata.columns]
variables = pl.Series(variables).unique().to_list()
# []

# Proposed code: keep every data column that appears in the formula as a whole word
formula = self.formula
columns = self.modeldata.columns
variables = list({var for var in columns if re.search(rf"\b{re.escape(var)}\b", formula)})
# ['price', 'carat', 'cut', 'color', 'depth'] (order may vary, since a set is used)
```

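One possible way to narrow the false positives, sketched here as a hypothetical refinement rather than part of the proposal above: tokenize the formula with Python's standard `tokenize` module, so quoted strings such as `'Good'` and `'center'` are never compared against column names, and skip names used as attribute accesses (`np.minimum`) or keyword arguments (`df=5`). Function names such as `Treatment` can still collide with a column of the same name, and columns referenced through patsy's `Q("...")` quoting would be missed, which the word-boundary check does find. The helper name `token_variables` and the example columns are illustrative.

```python
import io
import tokenize

FORMULA = (
    "price ~ depth:color + C(cut, Treatment('Good')) "
    "+ cr(np.minimum(carat, 0.8), df=5, constraints='center')"
)


def token_variables(formula, columns):
    """Return data columns that appear as bare names in `formula` (hypothetical helper)."""
    # Python's tokenizer splits the formula into NAME, STRING, NUMBER and OP tokens;
    # "~" and ":" are ordinary operator tokens, so the formula tokenizes cleanly.
    tokens = [
        tok
        for tok in tokenize.generate_tokens(io.StringIO(formula).readline)
        if tok.type not in (tokenize.NEWLINE, tokenize.NL, tokenize.ENDMARKER)
    ]
    names = set()
    for i, tok in enumerate(tokens):
        if tok.type != tokenize.NAME:
            continue  # string literals like 'Good' are STRING tokens and are ignored
        prev_str = tokens[i - 1].string if i > 0 else ""
        next_str = tokens[i + 1].string if i + 1 < len(tokens) else ""
        if prev_str == "." or next_str == "=":
            continue  # attribute access (np.minimum) or keyword argument (df=5)
        names.add(tok.string)
    # Preserve the order of the data columns.
    return [c for c in columns if c in names]


columns = ["carat", "cut", "color", "clarity", "depth", "price", "Good", "df", "center"]
print(token_variables(FORMULA, columns))  # ['carat', 'cut', 'color', 'depth', 'price']
```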
vincentarelbundock commented 4 months ago

I like this a lot! Thanks for the suggestion.

vincentarelbundock commented 4 months ago

Thanks again for the report. Fixed, and on PyPI as 0.0.9.

RoelVerbelen commented 4 months ago

Thank you for incorporating this, @vincentarelbundock and @LamAdr !