tidymodels / embed

Extra recipes for predictor embeddings
https://embed.tidymodels.org

catboost method to embed categorical variables #138

Open talegari opened 2 years ago

talegari commented 2 years ago

Hi Emil, I am planning to implement a step_catboost (along these lines). IMHO, it belongs here.

Let me know if you are open to a PR.

juliasilge commented 2 years ago

Unfortunately, catboost (the R package) is not on CRAN 😔, which blocks us from implementing catboost methods in our packages. You can see related discussion in catboost/catboost#439.

talegari commented 2 years ago

Hey Julia, step_catboost would not depend on the catboost package. The step involves permutations and target encoding. Here is the Python implementation of the same.
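As background on "permutations and target encoding": CatBoost-style encoding replaces each category with a smoothed mean of the outcome, computed only from the rows that come before the current row in a random permutation, so a row's own outcome never leaks into its encoding. The base R sketch below illustrates that idea under stated assumptions: the function name `ordered_target_encode()` and its `a` (prior weight) and `prior` arguments are illustrative, not the API of the linked Python implementation, and it uses a single permutation.

```r
# Hypothetical helper (not from the linked implementation): ordered target
# statistics for one categorical predictor. Rows are visited in a random order,
# and each row is encoded as
#   (sum of y over *earlier* rows with the same level + a * prior) /
#   (number of earlier rows with the same level + a)
ordered_target_encode <- function(x, y, a = 1, prior = mean(y)) {
  # assumption: no missing values in x or y
  stopifnot(length(x) == length(y), !anyNA(x), !anyNA(y), a > 0)

  perm <- sample.int(length(x))      # random processing order
  x_p <- as.character(x)[perm]
  y_p <- y[perm]

  # per-level running sum and count of the rows seen *before* each row
  prev_sum <- ave(y_p, x_p, FUN = function(v) cumsum(v) - v)
  prev_cnt <- ave(y_p, x_p, FUN = function(v) seq_along(v) - 1)

  enc_p <- (prev_sum + a * prior) / (prev_cnt + a)

  # undo the permutation so the result lines up with the input rows
  enc <- numeric(length(x))
  enc[perm] <- enc_p
  enc
}

set.seed(1)  # makes the permutation reproducible
ordered_target_encode(c("a", "a", "b", "a"), c(1, 2, 3, 4))
```

The first row of each level in the permutation always receives the prior; later rows are pulled toward the running level mean as more rows of that level are seen.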

EmilHvitfeldt commented 2 years ago

Hey @talegari 👋

That sounds great! Feel free to open an issue, and ping me if you need any help or assistance!

EmilHvitfeldt commented 1 year ago

Hello @talegari 👋 Are you still interested in opening a PR for this step? If not, I will do it.

talegari commented 1 year ago

Hey @EmilHvitfeldt... it just fell off the radar. I will submit a PR. I am planning along these lines. Let me know if you have a different suggestion.

EmilHvitfeldt commented 1 year ago

Amazing! That looks like a great place to start! Do you know when you will have time to work on this? No rush!

talegari commented 1 year ago

by 24th Mar

On Thu, March 16, 2023 at 09:34 PM, Emil Hvitfeldt < @.***> wrote:

Amazing! That looks like a great place to start! Do you know when you will have time to work on this? No rush!


talegari commented 1 year ago

Hey @EmilHvitfeldt, something unforeseen stopped me from working on this. Just letting you know that I am on it and will raise a PR shortly.

EmilHvitfeldt commented 1 year ago

No problem! It might not make it into the next {embed} release, but that is fine; we can send it in later.

talegari commented 1 year ago

@EmilHvitfeldt, I am one step away from raising a PR. I need your help resolving a small issue. Here is the context:

I have implemented the CatBoost encoder as an R6 class here:

Category encoder R6 class

```r
# catboost encoder core logic
pacman::p_load("tidyverse")

#' catboost_encoder R6 class
#'
#' An R6 class to encode categorical variables with the CatBoost method.
#'
#' @name catboost_encoder
#' @docType class
#' @importFrom R6 R6Class
#'
#' @slot dataset The dataset to fit the encoder
#' @slot mean The mean of the response variable in the dataset
#' @slot varnames_to_encode The names of the categorical variables to encode
#' @slot response_varname The name of the response variable in the dataset
#' @slot is_fitted A flag indicating whether the encoder has been fitted
#' @slot a A hyperparameter to control the strength of the encoding
#'
#' @section Public methods: \describe{
#'   \item{\code{initialize(dataset)}}{Constructor method for the
#'   catboost_encoder class}
#'   \item{\code{fit(varnames_to_encode, response_varname, a = 1)}}{Fit the
#'   encoder to the data}
#'   \item{\code{transform(new_data = NULL)}}{Transform a new dataset using the
#'   fitted encoder} }
#'
#' @section Private methods: \describe{
#'   \item{\code{encode_with_y(df, varname_to_encode, response_varname)}}{Encode
#'   a categorical variable using the response variable}
#'   \item{\code{encode_without_y(df, varname_to_encode, response_varname)}}{Encode
#'   a categorical variable without using the response variable} }
#'
#' @section Usage:
#'
#' catboost_encoder <- catboost_encoder$new(dataset)
#' catboost_encoder$fit(varnames_to_encode, response_varname)
#' encoded_data <- catboost_encoder$transform(new_data)
#'
#' @export catboost_encoder
catboost_encoder = R6::R6Class(
  "catboost_encoder",
  public = list(
    dataset = NULL,
    mean = NULL,
    varnames_to_encode = NULL,
    response_varname = NULL,
    is_fitted = FALSE,
    a = NULL,
    encode_novel_levels = NULL,
    encode_missing_levels = NULL,

    initialize = function(dataset){
      checkmate::assert_data_frame(dataset)
      self$dataset = dataset
      return(invisible(NULL))
    },

    fit = function(varnames_to_encode,
                   response_varname,
                   a = 1,
                   encode_novel_levels = TRUE,
                   encode_missing_levels = FALSE
                   ){
      checkmate::assert_string(response_varname)
      checkmate::assert_subset(response_varname,
                               choices = colnames(self$dataset)
                               )
      checkmate::assert_numeric(self$dataset[[response_varname]],
                                any.missing = FALSE
                                )
      checkmate::assert_character(varnames_to_encode)
      checkmate::assert_subset(varnames_to_encode,
                               choices = colnames(self$dataset)
                               )
      for (avarname in varnames_to_encode){
        checkmate::assert_factor(self$dataset[[avarname]])
      }
      checkmate::assert_number(a)
      checkmate::assert_flag(encode_novel_levels)
      checkmate::assert_flag(encode_missing_levels)

      self$varnames_to_encode = varnames_to_encode
      self$response_varname = response_varname
      self$mean = mean(self$dataset[[response_varname]], na.rm = TRUE)
      self$a = a
      # store the arguments (previously hard-coded to TRUE / FALSE)
      self$encode_novel_levels = encode_novel_levels
      self$encode_missing_levels = encode_missing_levels
      self$is_fitted = TRUE

      return(invisible(NULL))
    },

    transform = function(new_data = NULL){
      new_data_is_null = TRUE
      if (!is.null(new_data)){
        checkmate::assert_data_frame(new_data)
        checkmate::assert_false(self$response_varname %in% colnames(new_data))
        names_sorted = sort(colnames(new_data))
        checkmate::assert_set_equal(colnames(new_data),
                                    setdiff(colnames(self$dataset),
                                            self$response_varname
                                            )
                                    )
        checkmate::assert_set_equal(
          sapply(new_data, class)[names_sorted],
          sapply(dplyr::select(self$dataset, -c(self$response_varname)),
                 class
                 )[names_sorted]
          )
        new_data_is_null = FALSE
      }

      if (!self$is_fitted){
        stop("please 'fit' before 'transform'")
      }

      if (new_data_is_null){
        message("transforming on the dataset")
        new_data = self$dataset
      }

      if (new_data_is_null){
        # training data: ordered encoding that uses the response
        encoded_cols = map(self$varnames_to_encode,
                           ~ private$encode_with_y(new_data, .x)
                           )
      } else {
        # new data: smoothed level means learned from the training data
        encoded_cols = map(self$varnames_to_encode,
                           ~ private$encode_without_y(new_data, .x)
                           )
      }
      names(encoded_cols) = self$varnames_to_encode

      res =
        as_tibble(encoded_cols) %>%
        bind_cols(select(new_data, -c(self$varnames_to_encode))) %>%
        relocate(colnames(new_data))

      # encode novel (in new data case only)
      if (self$encode_novel_levels && !new_data_is_null){
        for (avarname in self$varnames_to_encode){
          new_levels = setdiff(levels(new_data[[avarname]]),
                               levels(self$dataset[[avarname]])
                               )
          if (length(new_levels) > 0){
            res[[avarname]] = ifelse(new_data[[avarname]] %in% new_levels,
                                     self$mean,
                                     res[[avarname]]
                                     )
          }
        }
      }

      # encode missing (in new data case only)
      if (self$encode_missing_levels && !new_data_is_null){
        for (avarname in self$varnames_to_encode){
          res[[avarname]][ is.na(new_data[[avarname]]) ] = NA
        }
      }

      return(res)
    }
  ),

  private = list(
    encode_with_y = function(df, varname_to_encode){
      # new levels: not applicable
      # NA: encoded
      res =
        df %>%
        select(all_of(c(varname_to_encode, self$response_varname))) %>%
        group_by(.data[[varname_to_encode]]) %>%
        mutate(cs__ = cumsum(.data[[self$response_varname]]),
               cc__ = row_number() - 1L
               ) %>%
        ungroup() %>%
        transmute({{varname_to_encode}} :=
                    (cs__ - .data[[self$response_varname]] +
                       mean(.data[[self$response_varname]], na.rm = TRUE) * self$a
                     ) / (cc__ + self$a)
                  ) %>%
        pull()

      return(res)
    },

    encode_without_y = function(df, varname_to_encode){
      # new levels: NA
      # NA: NA
      level_means = "level_means__"
      agg_frame =
        self$dataset %>%
        select(all_of(c(varname_to_encode, self$response_varname))) %>%
        group_by(.data[[varname_to_encode]]) %>%
        summarise(sum__ = sum(.data[[self$response_varname]], na.rm = TRUE),
                  count__ = n()
                  ) %>%
        ungroup() %>%
        mutate(level_means__ = ifelse(count__ == 1,
                                      self$mean,
                                      (sum__ + self$mean * self$a) / (count__ + self$a)
                                      )
               ) %>%
        drop_na(all_of(varname_to_encode)) %>%
        select(all_of(c(varname_to_encode, level_means)))

      res =
        df %>%
        select(all_of(c(varname_to_encode))) %>%
        left_join(agg_frame, by = varname_to_encode) %>%
        pull(level_means)

      return(res)
    }
  )
)
```
recipe wrapper as 'step_catboost'

```r
step_catboost = function(recipe,
                         ...,
                         role = NA,
                         trained = FALSE,
                         outcome = NULL,
                         mapping = NULL,
                         skip = FALSE,
                         id = rand_id("catboost")
                         ){
  if (is.null(outcome)) {
    rlang::abort("Please list a variable in `outcome`")
  }
  recipes:::add_step(
    recipe,
    step_catboost_new(
      terms = enquos(...),
      role = role,
      trained = trained,
      outcome = outcome,
      mapping = mapping,
      skip = skip,
      id = id
    )
  )
}

step_catboost_new = function(terms, role, trained, outcome, mapping, skip, id){
  step(
    subclass = "catboost",
    terms = terms,
    role = role,
    trained = trained,
    outcome = outcome,
    mapping = mapping,
    skip = skip,
    id = id
  )
}

#' @export
prep.step_catboost = function(x, training, info = NULL, ...){
  col_names = recipes_eval_select(x$terms, training, info)

  if (length(col_names) > 0) {
    y_name = recipes_eval_select(x$outcome, training, info)
    # instantiate R6 class obj
    ce = catboost_encoder$new(training)
    ce$fit(varnames_to_encode = col_names,
           response_varname = y_name
           )
  } else {
    ce = list()
  }

  step_catboost_new(
    terms = x$terms,
    role = x$role,
    trained = TRUE,
    outcome = x$outcome,
    mapping = ce,
    skip = x$skip,
    id = x$id
  )
}

#' @export
bake.step_catboost = function(object, new_data, ...) {
  if (!is.null(new_data)){
    y_name = purrr::map_chr(object$outcome, rlang::as_name) # string
    ce = object$mapping
    if (y_name %in% colnames(new_data)){
      new_data[[y_name]] = NULL
    }
    res = ce$transform(new_data)
  } else {
    res = ce$transform()
  }
  res = ce$transform(new_data)
  return(res)
}

#' @rdname required_pkgs.embed
#' @export
required_pkgs.step_catboost = function(x, ...) {
  c("embed")
}
```
Example

```r
pacman::p_load("recipes", "tidyverse")
source("~/personal/catboost_encoding_r6.R")
#> transforming on the dataset
#> transforming on the dataset
source("~/personal/step_catboost.R")

pen1 = palmerpenguins::penguins %>%
  drop_na(bill_length_mm) %>%
  slice_sample(prop = 0.7, by = 'species')
pen2 = palmerpenguins::penguins %>%
  drop_na(bill_length_mm) %>%
  setdiff(pen1)

# example with R6 class
ce = catboost_encoder$new(pen1)
ce$fit(c('species', 'sex'), response_varname = 'bill_length_mm')

# when input to transform is empty, it uses the training dataset
# (here it is pen1)
ce$transform()
#> transforming on the dataset
#> # A tibble: 238 × 8
#>    species island bill_length_mm bill_depth_mm flipper_…¹ body_…² sex year
#>
#>  1 43.8 Torgersen 39.6 17.2 196 3550 43.8 2008
#>  2 41.7 Dream 37.5 18.9 179 2975 43.8 2007
#>  3 40.3 Biscoe 35.5 16.2 195 3350 41.7 2008
#>  4 39.1 Torgersen 40.6 19 199 4000 43.8 2009
#>  5 39.4 Biscoe 40.1 18.9 188 4300 42.2 2008
#>  6 39.5 Dream 39.6 18.8 190 4600 41.5 2007
#>  7 39.5 Dream 32.1 15.5 188 3050 39.6 2009
#>  8 38.6 Dream 39.8 19.1 184 4650 41.0 2007
#>  9 38.7 Torgersen 34.1 18.1 193 3475 40.6 2007
#> 10 38.3 Dream 37 16.9 185 3000 37.7 2007
#> # … with 228 more rows, and abbreviated variable names ¹flipper_length_mm,
#> #   ²body_mass_g

# transform on a new dataset
ce$transform(pen2 %>% select(-bill_length_mm))
#> # A tibble: 104 × 7
#>    species island bill_depth_mm flipper_length_mm body_mass_g sex year
#>
#>  1 38.7 Torgersen 18 195 3250 42.2 2007
#>  2 38.7 Torgersen 20.6 190 3650 45.6 2007
#>  3 38.7 Torgersen 17.8 181 3625 42.2 2007
#>  4 38.7 Torgersen 19.6 195 4675 45.6 2007
#>  5 38.7 Torgersen 21.2 191 3800 45.6 2007
#>  6 38.7 Torgersen 17.8 185 3700 42.2 2007
#>  7 38.7 Torgersen 20.7 197 4500 45.6 2007
#>  8 38.7 Torgersen 21.5 194 4200 45.6 2007
#>  9 38.7 Biscoe 18.6 172 3150 42.2 2007
#> 10 38.7 Dream 16.7 178 3250 42.2 2007
#> # … with 94 more rows

# example with step_catboost recipe
ar = recipe(bill_length_mm ~ ., data = pen1) %>%
  step_catboost(species, outcome = "bill_length_mm") %>%
  prep(training = pen1)
ar
#> Recipe
#>
#> Inputs:
#>
#>       role #variables
#>    outcome          1
#>  predictor          7
#>
#> Training data contained 238 data points and 9 incomplete rows.
#>
#> Operations:
#>
#> $terms
#> <list_of<quosure>>
#>
#> [[1]]
#> <quosure>
#> expr: ^species
#> env:  0x7fbbb5a65120
#>
#>
#> $role
#> [1] NA
#>
#> $trained
#> [1] TRUE
#>
#> $outcome
#> [1] "bill_length_mm"
#>
#> $mapping
#> <catboost_encoder>
#>   Public:
#>     a: 1
#>     clone: function (deep = FALSE)
#>     dataset: tbl_df, tbl, data.frame
#>     encode_missing_levels: FALSE
#>     encode_novel_levels: TRUE
#>     fit: function (varnames_to_encode, response_varname, a = 1, encode_novel_levels = TRUE,
#>     initialize: function (dataset)
#>     is_fitted: TRUE
#>     mean: 43.7655462184874
#>     response_varname: bill_length_mm
#>     transform: function (new_data = NULL)
#>     varnames_to_encode: species
#>   Private:
#>     encode_with_y: function (df, varname_to_encode)
#>     encode_without_y: function (df, varname_to_encode)
#>
#> $skip
#> [1] FALSE
#>
#> $id
#> [1] "catboost_LGVzz"
#>
#> attr(,"class")
#> [1] "step_catboost" "step"

ar %>% juice()
#> # A tibble: 238 × 7
#>    species island bill_depth_mm flipper_length_mm body_mass_g sex year
#>
#>  1 38.7 Torgersen 17.2 196 3550 female 2008
#>  2 38.7 Dream 18.9 179 2975 <NA> 2007
#>  3 38.7 Biscoe 16.2 195 3350 female 2008
#>  4 38.7 Torgersen 19 199 4000 male 2009
#>  5 38.7 Biscoe 18.9 188 4300 male 2008
#>  6 38.7 Dream 18.8 190 4600 male 2007
#>  7 38.7 Dream 15.5 188 3050 female 2009
#>  8 38.7 Dream 19.1 184 4650 male 2007
#>  9 38.7 Torgersen 18.1 193 3475 <NA> 2007
#> 10 38.7 Dream 16.9 185 3000 female 2007
#> # … with 228 more rows

ar %>% bake(new_data = NULL)
#> # A tibble: 238 × 7
#>    species island bill_depth_mm flipper_length_mm body_mass_g sex year
#>
#>  1 38.7 Torgersen 17.2 196 3550 female 2008
#>  2 38.7 Dream 18.9 179 2975 <NA> 2007
#>  3 38.7 Biscoe 16.2 195 3350 female 2008
#>  4 38.7 Torgersen 19 199 4000 male 2009
#>  5 38.7 Biscoe 18.9 188 4300 male 2008
#>  6 38.7 Dream 18.8 190 4600 male 2007
#>  7 38.7 Dream 15.5 188 3050 female 2009
#>  8 38.7 Dream 19.1 184 4650 male 2007
#>  9 38.7 Torgersen 18.1 193 3475 <NA> 2007
#> 10 38.7 Dream 16.9 185 3000 female 2007
#> # … with 228 more rows

ar %>% bake(new_data = pen1)
#> # A tibble: 238 × 7
#>    species island bill_depth_mm flipper_length_mm body_mass_g sex year
#>
#>  1 38.7 Torgersen 17.2 196 3550 female 2008
#>  2 38.7 Dream 18.9 179 2975 <NA> 2007
#>  3 38.7 Biscoe 16.2 195 3350 female 2008
#>  4 38.7 Torgersen 19 199 4000 male 2009
#>  5 38.7 Biscoe 18.9 188 4300 male 2008
#>  6 38.7 Dream 18.8 190 4600 male 2007
#>  7 38.7 Dream 15.5 188 3050 female 2009
#>  8 38.7 Dream 19.1 184 4650 male 2007
#>  9 38.7 Torgersen 18.1 193 3475 <NA> 2007
#> 10 38.7 Dream 16.9 185 3000 female 2007
#> # … with 228 more rows

ar %>% bake(new_data = pen2)
#> # A tibble: 104 × 7
#>    species island bill_depth_mm flipper_length_mm body_mass_g sex year
#>
#>  1 38.7 Torgersen 18 195 3250 female 2007
#>  2 38.7 Torgersen 20.6 190 3650 male 2007
#>  3 38.7 Torgersen 17.8 181 3625 female 2007
#>  4 38.7 Torgersen 19.6 195 4675 male 2007
#>  5 38.7 Torgersen 21.2 191 3800 male 2007
#>  6 38.7 Torgersen 17.8 185 3700 female 2007
#>  7 38.7 Torgersen 20.7 197 4500 male 2007
#>  8 38.7 Torgersen 21.5 194 4200 male 2007
#>  9 38.7 Biscoe 18.6 172 3150 female 2007
#> 10 38.7 Dream 16.7 178 3250 female 2007
#> # … with 94 more rows
```

Issue: ce$transform() and ar %>% bake(new_data = NULL) give different results. How do I resolve this?

EmilHvitfeldt commented 1 year ago

Hello @talegari Sorry for taking a while to answer.

I'm not terribly familiar with {R6}, so I'm not sure how much I can help you. However, I can tell you where the problem likely is. In bake.step_catboost() you have

```r
if (!is.null(new_data)){
  y_name = purrr::map_chr(object$outcome, rlang::as_name) # string
  ce = object$mapping
  if (y_name %in% colnames(new_data)){
    new_data[[y_name]] = NULL
  }
  res = ce$transform(new_data)
} else {
  res = ce$transform()
}
```

I assume you thought this was needed to handle bake(new_data = NULL). That is not the case: the data passed to any bake method is always a non-NULL tibble. When you call bake(new_data = NULL), recipes extracts ar$template and does a couple of other things, so it just returns the data we got when running prep()/bake() the first time.
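Following that point, a bake method for this step needs no NULL branch. Below is a minimal sketch, assuming (as in the code above) that `outcome` holds a single column name as a string and that `mapping` is either the fitted encoder, exposing `varnames_to_encode` and `transform()`, or `list()` when `prep()` selected no columns. It also writes back only the encoded columns, so the outcome is kept in the baked data (in the output above it is dropped). It does not change which encoding the training rows receive: they still go through the `transform(new_data)` path.

```r
# Sketch: bake.step_catboost() without a NULL branch. recipes always hands
# bake methods a data frame, including when bake(new_data = NULL) is called.
bake.step_catboost <- function(object, new_data, ...) {
  encoder <- object$mapping
  if (length(encoder) == 0) {
    return(new_data)                 # prep() selected no columns
  }

  # transform() expects predictors only, so drop the outcome if it is present
  # (it is absent when baking data for prediction)
  y_name <- object$outcome           # assumed: a single column name (string)
  predictors <- new_data[setdiff(names(new_data), y_name)]
  encoded <- encoder$transform(predictors)

  # overwrite only the encoded columns; the outcome and other columns are kept
  for (col in encoder$varnames_to_encode) {
    new_data[[col]] <- encoded[[col]]
  }
  new_data
}
```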

Secondly, I'm sad to say this since you put in a lot of effort, but I don't want to add {R6} and {checkmate} as dependencies just for this step. If you don't want to go through the work of translating away from {R6} and {checkmate}, I understand, and if you want, I can take over and do the last parts.
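If the step is translated away from {R6} and {checkmate}, one option is to fit a plain list that the step keeps in its `mapping` field and to validate inputs with ordinary `stop()` calls. A minimal base R sketch of the level-mean encoding used for new data (the `encode_without_y()` path above) follows; the helper names `catboost_fit_mapping()` and `catboost_apply_mapping()` are hypothetical, and the ordered encoding of the training rows is not covered.

```r
# Hypothetical helpers (names are illustrative): the level-mean part of the
# encoder without R6 or checkmate. The fitted object is a plain list.
catboost_fit_mapping <- function(data, cols, outcome, a = 1) {
  y <- data[[outcome]]
  if (!is.numeric(y) || anyNA(y)) {
    stop("`outcome` must be a numeric column without missing values.", call. = FALSE)
  }
  prior <- mean(y)

  mapping <- lapply(cols, function(col) {
    x <- as.character(data[[col]])
    keep <- !is.na(x)                      # NA is not treated as a level
    sums <- tapply(y[keep], x[keep], sum)
    counts <- tapply(y[keep], x[keep], length)
    data.frame(
      level = names(sums),
      # smoothed level mean; levels seen once fall back to the prior,
      # as in encode_without_y() above
      value = as.numeric(ifelse(counts == 1, prior, (sums + a * prior) / (counts + a))),
      stringsAsFactors = FALSE
    )
  })
  names(mapping) <- cols
  list(mapping = mapping, prior = prior)
}

catboost_apply_mapping <- function(new_data, fitted) {
  for (col in names(fitted$mapping)) {
    m <- fitted$mapping[[col]]
    x <- as.character(new_data[[col]])
    enc <- m$value[match(x, m$level)]
    enc[!is.na(x) & is.na(enc)] <- fitted$prior   # novel levels -> prior
    new_data[[col]] <- enc                         # missing values stay NA
  }
  new_data
}

train <- data.frame(g = factor(c("a", "a", "b", "c")), y = c(1, 3, 10, 6))
fitted <- catboost_fit_mapping(train, cols = "g", outcome = "y", a = 1)
# encodes g as 3, 5, 5, NA (prior = 5; "z" is a novel level)
catboost_apply_mapping(data.frame(g = c("a", "b", "z", NA)), fitted)
```

Because the fitted object is just a list of data frames, it prints and serializes like any other step field, and `prep()`/`bake()` only need to call these two functions.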

Thanks again for all the work!