vincentarelbundock / pymarginaleffects

GNU General Public License v3.0

Improve plotting functions #114

Closed vincentarelbundock closed 1 month ago

vincentarelbundock commented 2 months ago

Bring plotting functions to feature parity with R. In particular, examples like these should work:

Python versions

import statsmodels.formula.api as smf
from marginaleffects import *
from plotnine import *
import polars as pl

dat = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv", null_values = "NA")

mod = smf.ols(
  "body_mass_g ~ flipper_length_mm * species * bill_length_mm + island", 
  data = dat.to_pandas()).fit()

# A
plot_predictions(mod, condition = ["flipper_length_mm", "species"])

# B
plot_predictions(mod, condition = ["flipper_length_mm", "bill_length_mm"])

# C
plot_predictions(mod, condition = {"flipper_length_mm": None, "species": ["Adelie", "Chinstrap"]})

# D
plot_predictions(mod, condition = {"flipper_length_mm": None, "bill_length_mm": "threenum"})

# E
plot_predictions(mod, condition = {"flipper_length_mm": None, "bill_length_mm": "fivenum"})

# F
plot_predictions(mod, condition = {"flipper_length_mm": None, "bill_length_mm": "minmax"})
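
The shortcut strings used above ("threenum", "fivenum", "minmax") each expand to a small set of reference values for a numeric variable. Below is a minimal, standard-library sketch of what they could compute. It assumes the definitions documented for the R package (mean plus or minus one standard deviation, Tukey's five numbers as in `stats::fivenum`, and minimum/maximum). `shortcut_values` is a hypothetical helper for illustration, not a function of the Python package.

```python
import math
import statistics


def fivenum(values):
    """Tukey's five-number summary, following the algorithm of R's stats::fivenum."""
    x = sorted(values)
    n = len(x)
    if n == 0:
        raise ValueError("fivenum() needs at least one value")
    n4 = math.floor((n + 3) / 2) / 2
    # 1-based (possibly fractional) positions of:
    # minimum, lower hinge, median, upper hinge, maximum
    positions = [1, n4, (n + 1) / 2, n + 1 - n4, n]
    return [0.5 * (x[math.floor(p) - 1] + x[math.ceil(p) - 1]) for p in positions]


def shortcut_values(values, shortcut):
    """Hypothetical helper: expand a shortcut string into reference values."""
    if shortcut == "minmax":
        return [min(values), max(values)]
    if shortcut == "threenum":
        center = statistics.mean(values)
        spread = statistics.stdev(values)  # sample standard deviation (assumption)
        return [center - spread, center, center + spread]
    if shortcut == "fivenum":
        return fivenum(values)
    raise ValueError(f"unknown shortcut: {shortcut!r}")


print(shortcut_values([1, 2, 3, 4], "fivenum"))
print(shortcut_values([1, 2, 3, 4, 5], "minmax"))
```

Each returned value would become one line (or one facet) in the plot, which is why "minmax" gives two groups, "threenum" three, and "fivenum" five.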

R versions

library(marginaleffects)

dat = read.csv("https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv")

mod = glm(body_mass_g ~ flipper_length_mm * species * bill_length_mm + island, data = dat)

# A
plot_predictions(mod, condition = c("flipper_length_mm", "species"))

# B
plot_predictions(mod, condition = c("flipper_length_mm", "bill_length_mm"))

# C
plot_predictions(mod, condition = list("flipper_length_mm", "species" = c("Adelie", "Chinstrap")))

# D
plot_predictions(mod, condition = list("flipper_length_mm", "bill_length_mm" = "threenum"))

# E
plot_predictions(mod, condition = list("flipper_length_mm", "bill_length_mm" = "fivenum"))

# F
plot_predictions(mod, condition = list("flipper_length_mm", "bill_length_mm" = "minmax"))

artiom-matvei commented 1 month ago

Extra tests

R version

# H
plot_predictions(mod, condition = list("flipper_length_mm", "species" = c("Adelie", "Chinstrap"), "bill_length_mm"))

# I
plot_predictions(mod, condition = list("flipper_length_mm", "species" = c("Adelie", "Chinstrap"), "bill_length_mm", "island"))

Python version

# H
plot_predictions(mod, condition = {"flipper_length_mm": None, "species": ["Adelie", "Chinstrap"], "bill_length_mm": None})

# I
plot_predictions(mod, condition = {"flipper_length_mm": None, "species": ["Adelie", "Chinstrap"], "bill_length_mm": None, "island": None})

It looks like support for 4 variables also needs to be added, since example I yields:

AssertionError: Lenght of condition must be inclusively between 1 and 3. Got : 4.

vincentarelbundock commented 1 month ago

Yep, thanks. Also plot_comparisons() and plot_slopes()

artiom-matvei commented 1 month ago

The 4 variables are supported, but the vertical and horizontal axes are inverted, as seen in the screenshot below (LHS is Python, RHS is R). I assume this needs to be fixed? image

vincentarelbundock commented 1 month ago

Yes, thanks. Should be consistent.

artiom-matvei commented 1 month ago

@vincentarelbundock if I am not wrong, it looks like the maximum number of variables accepted by condition is 4 in R and 3 in Python. So this would be another point to improve?

Python Version

def plot_slopes():
###...
    assert (
        len(var_list) < 4
    ), "The `condition` and `by` arguments can have a max length of 3."
###...

R Version

# roxygen documentation for plot_slopes()
###...
#' @param condition Conditional slopes
#' + Character vector (max length 4): Names of the predictors to display.
#' + Named list (max length 4): List names correspond to predictors. List elements can be:
#'   - Numeric vector
#'   - Function which returns a numeric vector or a set of unique categorical values 
#'   - Shortcut strings for common reference values: "minmax", "quartile", "threenum"
#' + 1: x-axis. 2: color/shape. 3: facet (wrap if no fourth variable, otherwise cols of grid). 4: facet (rows of grid).
#' + Numeric variables in positions 2 and 3 are summarized by Tukey's five numbers `?stats::fivenum`.

###...

vincentarelbundock commented 1 month ago

Good catch. Yes, sounds like a good improvement.
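
For reference, a minimal sketch of how the positional rules quoted from the R documentation could be expressed in Python, accepting up to 4 variables. The function name `assign_roles` and the role names are hypothetical and not taken from the package. The sketch only shows the validation and the position-to-role mapping, not any plotting.

```python
# Hypothetical role names, ordered as in the R documentation:
# 1: x-axis, 2: color/shape, 3: facet (wrap or grid columns), 4: facet (grid rows)
ROLES = ("x", "color", "facet_cols", "facet_rows")


def assign_roles(condition):
    """Map each `condition` variable to a plot role by its position."""
    if isinstance(condition, str):
        condition = [condition]
    # A list of names is treated like a dict whose values are all None.
    if not isinstance(condition, dict):
        condition = {name: None for name in condition}
    if not 1 <= len(condition) <= len(ROLES):
        raise ValueError(
            f"`condition` must name between 1 and {len(ROLES)} variables, "
            f"got {len(condition)}."
        )
    # Dicts preserve insertion order, so position follows the order given by the user.
    roles = dict(zip(ROLES, condition))
    # Three variables: wrap the facets. Four variables: grid of columns by rows.
    layout = {1: None, 2: None, 3: "wrap", 4: "grid"}[len(condition)]
    return roles, layout


roles, layout = assign_roles(
    {
        "flipper_length_mm": None,
        "species": ["Adelie", "Chinstrap"],
        "bill_length_mm": None,
        "island": None,
    }
)
print(roles, layout)
```

Raising a `ValueError` rather than using `assert` is a design choice in this sketch: assertions are stripped when Python runs with `-O`, so input validation would silently disappear.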

artiom-matvei commented 1 month ago

There was an error in the order in which the labels appear (left: after fix, right: before fix) image

It seems this was because the facet labels were being treated as strings and ordered accordingly. I made sure the labels are not converted to strings with this change in p9.py:

def plot_common(model, dt, y_label, var_list):
###...
    # treat all variables except x-axis as categorical
    for i in range(1, len(var_list)):
        col = var_list[i]
        if dt[col].dtype.is_numeric() and i != 1:
            # numeric facet variables (third position onward) stay numeric,
            # so their facet labels keep numeric order
            continue
        if dt[col].dtype != pl.Categorical:
            dt = dt.with_columns(pl.col(col).cast(pl.Utf8))
###...

There is a comment saying that the variables need to be treated as categorical. Any idea why? @vincentarelbundock

vincentarelbundock commented 1 month ago

Categorical in the sense of "not continuous" so we draw points rather than lines.

I'm not sure if "ordered" matters or not. You can just check.
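
The effect of ordering on facet labels can be checked without polars or plotnine. A minimal sketch with made-up values: once numeric labels are converted to strings, sorting becomes lexicographic, which is the kind of misordering described above.

```python
# Made-up numeric facet values, for illustration only
facet_values = [35.0, 40.5, 100.0, 9.5]

# Sorted as numbers: the order a reader expects for facet panels
as_numbers = sorted(facet_values)

# Sorted after conversion to strings: compared character by character
as_strings = sorted(str(v) for v in facet_values)

print(as_numbers)  # [9.5, 35.0, 40.5, 100.0]
print(as_strings)  # ['100.0', '35.0', '40.5', '9.5']
```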

artiom-matvei commented 1 month ago

Inconsistent subplot arrangement

R and Python arrange the subplots within a plot differently.

R

dat <- read.csv("https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv")
mod <- glm("body_mass_g ~ flipper_length_mm * species * bill_length_mm * island", data = dat)
p <- plot_slopes(mod, variables = "species", 
                 condition = list("flipper_length_mm", "species"=c("Adelie", "Chinstrap"), "bill_length_mm")
)

image

Python

image

df = pl.read_csv(
    "https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv",
    null_values="NA",
).drop_nulls()
mod = smf.ols(
    "body_mass_g ~ flipper_length_mm * species * bill_length_mm * island",
    df.to_pandas(),
).fit()
fig = plot_slopes(
    mod,
    variables=["species"],
    condition={
        "flipper_length_mm": None,
        "species": ["Adelie", "Chinstrap"],
        "bill_length_mm": None,
    },
)

What are your thoughts? Should we fix it in R? @vincentarelbundock

vincentarelbundock commented 1 month ago

Yeah, I suppose Python is better in this case.

vincentarelbundock commented 1 month ago

Open an issue with the example, but don't work on that. Priority: Python.

artiom-matvei commented 1 month ago

Does it make sense to plot something like this? Specifically, having species in both variables and condition?

R

In R we get a graph like this which seems correct: image

dat <- read.csv("https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv")
mod <- glm("body_mass_g ~ flipper_length_mm * species * bill_length_mm * island", data = dat)
p <- plot_slopes(mod, variables = "species", 
                 condition = list("flipper_length_mm", "species"=c("Adelie", "Chinstrap"), "bill_length_mm", "island")
)

Python

In Python, however, we get nonsense: image

df = pl.read_csv(
    "https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv",
    null_values="NA",
).drop_nulls()
mod = smf.ols(
    "body_mass_g ~ flipper_length_mm * species * bill_length_mm * island",
    df.to_pandas(),
).fit()
plot_slopes(
    mod,
    variables=["species"],
    condition={
        "flipper_length_mm": None,
        "species": ["Adelie", "Chinstrap"],
        "bill_length_mm": None,
        "island": None,
    },
)

If it makes sense, should I spend time debugging why the graph does not come up in Python? @vincentarelbundock

vincentarelbundock commented 1 month ago

No, I don't think it makes much sense. Low priority.