tidymodels / parsnip

A tidy unified interface to models
https://parsnip.tidymodels.org

Support for multi-level outcomes in MARS classification #472

Open vadimus202 opened 3 years ago

vadimus202 commented 3 years ago

Multi-level outcomes in MARS classification

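Looking at the code in mars_data.R, it appears that only binary classification is currently supported for prediction. Both post functions below read only the first column of the prediction and assume the outcome has exactly two levels: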

set_pred(
  model = "mars",
  eng = "earth",
  mode = "classification",
  type = "class",
  value = list(
    pre = NULL,
    post = function(x, object) {
      x <- ifelse(x[, 1] >= 0.5, object$lvl[2], object$lvl[1])
      x
    },
    func = c(fun = "predict"),
    args =
      list(
        object = quote(object$fit),
        newdata = quote(new_data),
        type = "response"
      )
  )
)

set_pred(
  model = "mars",
  eng = "earth",
  mode = "classification",
  type = "prob",
  value = list(
    pre = NULL,
    post = function(x, object) {
      x <- x[, 1]
      x <- tibble(v1 = 1 - x, v2 = x)
      colnames(x) <- object$lvl
      x
    },
    func = c(fun = "predict"),
    args =
      list(
        object = quote(object$fit),
        newdata = quote(new_data),
        type = "response"
      )
  )
)
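
For illustration, here is a minimal sketch of what multiclass-aware post functions might look like. This is not the parsnip implementation. The names post_class and post_prob are hypothetical, and the sketch assumes that x is a numeric matrix with one column per outcome level (a single column in the binary case) and that its columns are in the same order as object$lvl. It uses a base R data frame instead of a tibble so that it runs without extra packages.

```r
# Hypothetical helpers, not part of parsnip.
# Assumption: `x` is a numeric matrix with one column per outcome level
# (one column in the binary case), ordered like `object$lvl`.

post_class <- function(x, object) {
  lvl <- object$lvl
  if (ncol(x) == 1) {
    # Binary case: same 0.5 threshold rule as the existing code.
    out <- ifelse(x[, 1] >= 0.5, lvl[2], lvl[1])
  } else {
    # Multiclass case: pick the level with the largest value in each row.
    out <- lvl[max.col(x, ties.method = "first")]
  }
  factor(unname(out), levels = lvl)
}

post_prob <- function(x, object) {
  lvl <- object$lvl
  if (ncol(x) == 1) {
    # Binary case: expand the single column into two complementary columns.
    x <- cbind(1 - x[, 1], x[, 1])
  }
  out <- as.data.frame(x)
  colnames(out) <- lvl
  out
}

# Usage with made-up values (not real model output):
demo_x <- matrix(c(0.7, 0.2, 0.1,
                   0.1, 0.3, 0.6), nrow = 2, byrow = TRUE)
demo_obj <- list(lvl = c("a", "b", "c"))
post_class(demo_x, demo_obj)
post_prob(demo_x, demo_obj)
```

One open design question: if the engine fits a separate binomial model per level, the per-row values need not sum to 1, so a real fix might also have to decide whether to renormalize them.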

juliasilge commented 3 years ago

Fitting works just fine in the multiclass case, it looks like, but the post handling at prediction time does not account for more than two classes:

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(earth)
#> Loading required package: Formula
#> Loading required package: plotmo
#> Loading required package: plotrix
#> 
#> Attaching package: 'plotrix'
#> The following object is masked from 'package:scales':
#> 
#>     rescale
#> Loading required package: TeachingDemos

data("scat")
scat_df <- scat %>% na.omit()

mars_spec <- 
  mars(prod_degree = 2) %>%
  set_engine("earth") %>% 
  set_mode("classification")

mars_spec %>%
  fit(Species ~ ., data = scat_df)
#> parsnip model object
#> 
#> Fit time:  38ms 
#> GLM (family binomial, link logit):
#>           nulldev df       dev df   devratio     AIC iters converged
#> bobcat   125.2612 90   68.8815 86      0.450   78.88     5         1
#> coyote   105.0010 90   48.2778 86      0.540   58.28     6         1
#> gray_fox  87.6455 90   46.0286 86      0.475   56.03     9         1
#> 
#> Earth selected 5 of 31 terms, and 5 of 26 predictors
#> Termination condition: GRSq -10 at 31 terms
#> Importance: d13C, CN, Diameter, Mass, ropey, MonthAugust-unused, ...
#> Number of terms at each degree of interaction: 1 3 1
#> 
#> Earth
#>                GCV       RSS      GRSq       RSq
#> bobcat   0.1693942 11.913440 0.3306862 0.4711595
#> coyote   0.1117206  7.857276 0.4372281 0.5553408
#> gray_fox 0.1258790  8.853030 0.1894912 0.3595980
#> All      0.4069939 28.623745 0.3294039 0.4701463

Created on 2021-04-21 by the reprex package (v2.0.0)
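
To make the failure mode concrete, the sketch below applies the existing "class" post function from the issue text to a three-column matrix. The matrix values are made up for illustration and are not output from the model above. The sketch assumes the prediction for a three-level outcome arrives as one column per level. post_binary, fake_pred and fake_object are hypothetical names.

```r
# The existing "class" post function, copied from the issue text.
post_binary <- function(x, object) {
  x <- ifelse(x[, 1] >= 0.5, object$lvl[2], object$lvl[1])
  x
}

# Made-up values standing in for a three-column prediction matrix.
fake_pred <- matrix(
  c(0.8, 0.1, 0.1,
    0.1, 0.2, 0.7),
  nrow = 2, byrow = TRUE,
  dimnames = list(NULL, c("bobcat", "coyote", "gray_fox"))
)
fake_object <- list(lvl = c("bobcat", "coyote", "gray_fox"))

# Only the first column is consulted, so "gray_fox" can never be returned,
# and a high value in the first column maps to the *second* level.
res <- post_binary(fake_pred, fake_object)
print(res)
```

With these inputs the first row, where bobcat has the largest value, is labelled coyote, and the second row, where gray_fox has the largest value, is labelled bobcat.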