vincentarelbundock / pymarginaleffects

GNU General Public License v3.0

Data column named "group" causes duplicate column issue for `predictions(..., by = ...)` #90

Closed RoelVerbelen closed 7 months ago

RoelVerbelen commented 7 months ago

Hi @vincentarelbundock

I thought I'd flag this one with you as well: predictions() with the by argument fails when the data has a column named "group".

Reprex:

import pandas as pd
import statsmodels.formula.api as smf
from marginaleffects import predictions

diamonds = pd.read_csv("https://raw.githubusercontent.com/vincentarelbundock/Rdatasets/master/csv/ggplot2/diamonds.csv")

model = smf.ols("price ~ cut", data=diamonds).fit()

# Works
predictions(model, newdata=diamonds, by="cut")

# Create column named group
diamonds["group"] = diamonds["color"]

# Fails
predictions(model, newdata=diamonds, by="cut")

DuplicateError: column with name 'group' has more than one occurrences

vincentarelbundock commented 7 months ago

Ah yeah, that's right. It's one of the annoying choices I made: "reserving" some keywords to avoid conflicts between the columns of the original data that we merge into the output and the columns produced by marginaleffects.

@LamAdr If you have a minute, do you think you could add a sanity check to `newdata` (or to the model data if `newdata` is `None`)? It should look for a reserved name such as `group`, `estimate`, or `conf.low`, and raise an informative error message.

Thanks!
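
For illustration, the sanity check requested above could look roughly like the sketch below. It is not the marginaleffects implementation: the helper name `check_reserved_columns`, the `RESERVED_NAMES` set (limited to the three names mentioned in this thread), and the choice of `ValueError` are all assumptions. The function only needs an iterable of column names, so it could be fed the `.columns` of `newdata` or of the model data.

```python
# Hypothetical helper: the name, the reserved set, and the error type are
# assumptions for illustration, not the actual marginaleffects code.
RESERVED_NAMES = frozenset({"group", "estimate", "conf.low"})


def check_reserved_columns(columns, reserved=RESERVED_NAMES):
    """Raise ValueError if any column name collides with a reserved name."""
    clashes = sorted(set(columns) & set(reserved))
    if clashes:
        raise ValueError(
            "The data contains column(s) with reserved names: "
            + ", ".join(repr(name) for name in clashes)
            + ". Please rename them before calling predictions()."
        )


# No reserved names: the check passes silently.
check_reserved_columns(["carat", "cut", "color", "price"])

# A column named "group" triggers the informative error.
try:
    check_reserved_columns(["carat", "cut", "price", "group"])
except ValueError as err:
    print(err)
```

Failing early with a message that names the offending column would point users to the cause directly, instead of the `DuplicateError` that currently surfaces from the merge step.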