rstudio / reticulate

R Interface to Python
https://rstudio.github.io/reticulate
Apache License 2.0

Module alias isn't recognized inside nested purrr map loop #712

Open ercbk opened 4 years ago

ercbk commented 4 years ago

The problem

Module aliases defined in what I assume is the global environment aren't propagating through my purrr::map2 loop. I started at the core of the nested loop and worked outward, and the sklearn alias is recognized at every level; it fails only at the outermost layer. Other reticulate functions, such as r_to_py, don't seem to have this problem.

When I place the import inside the model function itself, sklearn_rf_FUN, it works. There doesn't seem to be any speed cost to putting the import inside the model function, even though it is called repeatedly. I also tried running the import both inside and outside the model function, and it is actually faster without the outside import, although I only benchmarked it a couple of times. It's an easy fix with no apparent cost, but I thought the behavior was odd and might have consequences elsewhere.
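The workaround described above can be sketched as follows. This is a minimal illustration rather than the code from the reprex: `fit_sklearn_rf` and its arguments are hypothetical names, and calling it assumes reticulate can find a Python installation with scikit-learn. Python keeps already-imported modules in `sys.modules`, so a repeated import within the same process should normally be cheap, which would be consistent with the informal benchmark mentioned above.

```r
# Hypothetical model function: the module is imported where the function runs,
# so no alias has to be looked up in the calling session's global environment.
fit_sklearn_rf <- function(X_train, y_train, n_estimators = 100L) {
  # Assumption: reticulate and scikit-learn are available on the process
  # that executes this call (the main session or a parallel worker).
  skl_ensemble <- reticulate::import("sklearn.ensemble")
  model <- skl_ensemble$RandomForestRegressor(
    n_estimators = as.integer(n_estimators),
    random_state = 1L
  )
  # Return the fitted estimator
  model$fit(X_train, y_train)
}
```

Defining the function does not touch Python; the import only happens when the function is called.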

Reproducible example

pacman::p_load(tidymodels, data.table, dtplyr, dplyr, furrr, reticulate)

sklearn_e <- import("sklearn.ensemble")
# import("sklearn.ensemble", as = "sklearn_e")

plan(multiprocess)

set.seed(2019)

# simulated data; generates 10 multi-patterned, numeric predictors plus outcome variable
sim_data <- function(n) {
      tmp <- mlbench::mlbench.friedman1(n, sd=1)
      tmp <- cbind(tmp$x, tmp$y)
      tmp <- as.data.frame(tmp)
      names(tmp)[ncol(tmp)] <- "y"
      tmp
}

# Use small data to tune and compare models
small_dat <- sim_data(100)

ncv_dat_2 <- nested_cv(small_dat,
                       outside = vfold_cv(v = 2, repeats = 2),
                       inside = bootstraps(times = 2))

error_FUN <- function(y_obs, y_hat){
      y_obs <- unlist(y_obs)
      y_hat <- unlist(y_hat)
      Metrics::mae(y_obs, y_hat)
}

sklearn_rf_FUN <- function(params, analysis_set) {
      # sklearn_e <- import("sklearn.ensemble")
      max_features <- r_to_py(params$mtry[[1]])
      n_estimators <- r_to_py(params$trees[[1]])

      # get data into sklearn's preferred format
      y_idx <- ncol(analysis_set) - 1
      X_idx <- y_idx - 1 
      pAnal <- r_to_py(analysis_set)
      y_train <- pAnal$iloc(axis = 1L)[y_idx]$values
      X_train <- pAnal$iloc(axis = 1L)[0:X_idx]

      model <- sklearn_e$RandomForestRegressor(criterion = "mae",
                                               max_features = max_features,
                                               n_estimators = n_estimators,
                                               random_state = 1L)
      mod_fit <- model$fit(X_train, y_train)
}

rf_params <- grid_latin_hypercube(
      mtry(range = c(3, 4)),
      trees(range = c(200, 300)),
      size = 40
)

is_sklearn <- function(modfun) {
      string <- toString(body(modfun))
      stringr::str_detect(string, pattern = "sklearn")
}

# inputs params, model, and resample, calls model and error functions, outputs error
mod_error <- function(params, mod_FUN, dat) {
      y_col <- ncol(dat$data)
      y_obs <- assessment(dat)[y_col]
      mod <- mod_FUN(params, analysis(dat))

      if(is_sklearn(mod_FUN)) {
            X_dat <- r_to_py(assessment(dat)[-y_col])
            pred <- mod$predict(X_dat)
      } else {
            pred <- predict(mod, assessment(dat))
            if (!is.data.frame(pred)) {
                  pred <- pred$predictions
            }
      }

      error <- error_FUN(y_obs, pred)
      error
}

# inputs resample, loops hyperparam grid values to model/error function, collects error value for hyperparam combo
tune_over_params <- function(dat, mod_FUN, params) {
      params$error <- map_dbl(1:nrow(params), function(row) {
            params <- params[row,]
            mod_error(params, mod_FUN, dat)
      })
      params
}

# inputs and sends fold's resamples to tuning function, collects and averages fold's error for each hyperparameter combo
summarize_tune_results <- function(dat, mod_FUN, params) {
      # Return a row-bound tibble with the inner bootstrap results
      param_names <- names(params)
      future_map_dfr(dat$splits, tune_over_params, mod_FUN, params, .progress = TRUE) %>%
            lazy_dt(., key_by = param_names) %>% 
            # For each value of the tuning parameter, compute the
            # average <error> which is the inner bootstrap estimate.
            group_by_at(vars(param_names)) %>%
            summarize(mean_error = mean(error, na.rm = TRUE),
                      n = length(error)) %>% 
            as_tibble()
}

compare_algs <- function(mod_FUN, params, ncv_dat){
      # tune models by grid searching on resamples in the inner loop (e.g. 5 repeats 10 folds = list of 50 tibbles with param and mean_error cols)
      tuning_results <- map(ncv_dat$inner_resamples, summarize_tune_results, mod_FUN, params)
}

mod_FUN_list_skrf <- list(sklearn_rf = sklearn_rf_FUN)

params_list <- list(sklearn_rf = rf_params)

algorithm_comparison_two_skrf <- map2(mod_FUN_list_skrf, params_list, compare_algs, ncv_dat_2)
#> Progress: --------------------------------------------------- 100%
#> Error in mod_FUN(params, analysis(dat)): object 'sklearn_e' not found
#> Called from: signalConditions(obj, exclude = getOption("future.relay.immediate",
#>     "immediateCondition"), resignal = resignal, ...)

Created on 2020-02-04 by the reprex package (v0.3.0)

Current session info

```r
- Session info ---------------------------------------------------------------
 setting  value
 version  R version 3.6.2 (2019-12-12)
 os       Windows 10 x64
 system   x86_64, mingw32
 ui       RStudio
 language (EN)
 collate  English_United States.1252
 ctype    English_United States.1252
 tz       America/New_York
 date     2020-02-05

- Packages -------------------------------------------------------------------
 package       * version    date       lib source
 assertthat      0.2.1      2019-03-21 [1] CRAN (R 3.6.1)
 backports       1.1.5      2019-10-02 [1] CRAN (R 3.6.1)
 base64enc       0.1-3      2015-07-28 [1] CRAN (R 3.6.0)
 bayesplot       1.7.1      2019-12-01 [1] CRAN (R 3.6.2)
 boot            1.3-24     2019-12-20 [1] CRAN (R 3.6.2)
 broom         * 0.5.3      2019-12-14 [1] CRAN (R 3.6.2)
 callr           3.4.0      2019-12-09 [1] CRAN (R 3.6.2)
 class           7.3-15     2019-01-01 [2] CRAN (R 3.6.2)
 cli             2.0.1      2020-01-08 [1] CRAN (R 3.6.2)
 clipr           0.7.0      2019-07-23 [1] CRAN (R 3.6.1)
 codetools       0.2-16     2018-12-24 [2] CRAN (R 3.6.2)
 colorspace      1.4-1      2019-03-18 [1] CRAN (R 3.6.1)
 colourpicker    1.0        2017-09-27 [1] CRAN (R 3.6.1)
 crayon          1.3.4      2017-09-16 [1] CRAN (R 3.6.1)
 crosstalk       1.0.0      2016-12-21 [1] CRAN (R 3.6.1)
 data.table    * 1.12.8     2019-12-09 [1] CRAN (R 3.6.2)
 desc            1.2.0      2018-05-01 [1] CRAN (R 3.6.1)
 details       * 0.2.1      2020-01-12 [1] CRAN (R 3.6.2)
 dials         * 0.0.4      2019-12-02 [1] CRAN (R 3.6.2)
 DiceDesign      1.8-1      2019-07-31 [1] CRAN (R 3.6.1)
 digest          0.6.23     2019-11-23 [1] CRAN (R 3.6.2)
 dplyr         * 0.8.3      2019-07-04 [1] CRAN (R 3.6.1)
 DT              0.11       2019-12-19 [1] CRAN (R 3.6.2)
 dtplyr        * 1.0.0      2019-11-12 [1] CRAN (R 3.6.2)
 dygraphs        1.1.1.6    2018-07-11 [1] CRAN (R 3.6.1)
 fansi           0.4.1      2020-01-08 [1] CRAN (R 3.6.2)
 fastmap         1.0.1      2019-10-08 [1] CRAN (R 3.6.1)
 foreach         1.4.7      2019-07-27 [1] CRAN (R 3.6.1)
 furrr         * 0.1.0      2018-05-16 [1] CRAN (R 3.6.1)
 future        * 1.16.0     2020-01-16 [1] CRAN (R 3.6.2)
 generics        0.0.2      2018-11-29 [1] CRAN (R 3.6.1)
 ggplot2       * 3.2.1      2019-08-10 [1] CRAN (R 3.6.1)
 ggridges        0.5.2      2020-01-12 [1] CRAN (R 3.6.2)
 globals         0.12.5     2019-12-07 [1] CRAN (R 3.6.1)
 glue            1.3.1      2019-03-12 [1] CRAN (R 3.6.1)
 gower           0.2.1      2019-05-14 [1] CRAN (R 3.6.1)
 GPfit           1.0-8      2019-02-08 [1] CRAN (R 3.6.2)
 gridExtra       2.3        2017-09-09 [1] CRAN (R 3.6.1)
 gtable          0.3.0      2019-03-25 [1] CRAN (R 3.6.1)
 gtools          3.8.1      2018-06-26 [1] CRAN (R 3.6.0)
 htmltools       0.4.0      2019-10-04 [1] CRAN (R 3.6.1)
 htmlwidgets     1.5.1      2019-10-08 [1] CRAN (R 3.6.1)
 httpuv          1.5.2      2019-09-11 [1] CRAN (R 3.6.1)
 httr            1.4.1      2019-08-05 [1] CRAN (R 3.6.1)
 igraph          1.2.4.2    2019-11-27 [1] CRAN (R 3.6.2)
 infer         * 0.5.1      2019-11-19 [1] CRAN (R 3.6.2)
 inline          0.3.15     2018-05-18 [1] CRAN (R 3.6.1)
 ipred           0.9-9      2019-04-28 [1] CRAN (R 3.6.1)
 iterators       1.0.12     2019-07-26 [1] CRAN (R 3.6.1)
 janeaustenr     0.1.5      2017-06-10 [1] CRAN (R 3.6.1)
 jsonlite        1.6        2018-12-07 [1] CRAN (R 3.6.1)
 knitr           1.27       2020-01-16 [1] CRAN (R 3.6.2)
 later           1.0.0      2019-10-04 [1] CRAN (R 3.6.1)
 lattice         0.20-38    2018-11-04 [2] CRAN (R 3.6.2)
 lava            1.6.6      2019-08-01 [1] CRAN (R 3.6.1)
 lazyeval        0.2.2      2019-03-15 [1] CRAN (R 3.6.1)
 lhs             1.0.1      2019-02-03 [1] CRAN (R 3.6.1)
 lifecycle       0.1.0      2019-08-01 [1] CRAN (R 3.6.1)
 listenv         0.8.0      2019-12-05 [1] CRAN (R 3.6.2)
 lme4            1.1-21     2019-03-05 [1] CRAN (R 3.6.1)
 loo             2.2.0      2019-12-19 [1] CRAN (R 3.6.2)
 lubridate       1.7.4      2018-04-11 [1] CRAN (R 3.6.1)
 magrittr        1.5        2014-11-22 [1] CRAN (R 3.6.1)
 markdown        1.1        2019-08-07 [1] CRAN (R 3.6.1)
 MASS            7.3-51.4   2019-03-31 [2] CRAN (R 3.6.2)
 Matrix          1.2-18     2019-11-27 [2] CRAN (R 3.6.2)
 matrixStats     0.55.0     2019-09-07 [1] CRAN (R 3.6.1)
 mime            0.8        2019-12-19 [1] CRAN (R 3.6.2)
 miniUI          0.1.1.1    2018-05-18 [1] CRAN (R 3.6.1)
 minqa           1.2.4      2014-10-09 [1] CRAN (R 3.6.1)
 munsell         0.5.0      2018-06-12 [1] CRAN (R 3.6.1)
 nlme            3.1-143    2019-12-10 [1] CRAN (R 3.6.2)
 nloptr          1.2.1      2018-10-03 [1] CRAN (R 3.6.1)
 nnet            7.3-12     2016-02-02 [2] CRAN (R 3.6.2)
 pacman          0.5.1      2019-03-11 [1] CRAN (R 3.6.1)
 parsnip       * 0.0.5      2020-01-07 [1] CRAN (R 3.6.2)
 pillar          1.4.3      2019-12-20 [1] CRAN (R 3.6.2)
 pkgbuild        1.0.6      2019-10-09 [1] CRAN (R 3.6.1)
 pkgconfig       2.0.3      2019-09-22 [1] CRAN (R 3.6.1)
 plyr            1.8.5      2019-12-10 [1] CRAN (R 3.6.2)
 png             0.1-7      2013-12-03 [1] CRAN (R 3.6.0)
 prettyunits     1.1.0      2020-01-09 [1] CRAN (R 3.6.2)
 pROC            1.16.1     2020-01-14 [1] CRAN (R 3.6.2)
 processx        3.4.1      2019-07-18 [1] CRAN (R 3.6.1)
 prodlim         2019.11.13 2019-11-17 [1] CRAN (R 3.6.2)
 promises        1.1.0      2019-10-04 [1] CRAN (R 3.6.1)
 ps              1.3.0      2018-12-21 [1] CRAN (R 3.6.1)
 purrr         * 0.3.3      2019-10-18 [1] CRAN (R 3.6.2)
 R6              2.4.1      2019-11-12 [1] CRAN (R 3.6.2)
 Rcpp            1.0.3      2019-11-08 [1] CRAN (R 3.6.2)
 recipes       * 0.1.9      2020-01-07 [1] CRAN (R 3.6.2)
 reshape2        1.4.3      2017-12-11 [1] CRAN (R 3.6.1)
 reticulate    * 1.14       2019-12-17 [1] CRAN (R 3.6.2)
 rlang           0.4.2      2019-11-23 [1] CRAN (R 3.6.2)
 rpart           4.1-15     2019-04-12 [2] CRAN (R 3.6.2)
 rprojroot       1.3-2      2018-01-03 [1] CRAN (R 3.6.1)
 rsample       * 0.0.5      2019-07-12 [1] CRAN (R 3.6.1)
 rsconnect       0.8.16     2019-12-13 [1] CRAN (R 3.6.2)
 rstan           2.19.2     2019-07-09 [1] CRAN (R 3.6.1)
 rstanarm        2.19.2     2019-10-03 [1] CRAN (R 3.6.1)
 rstantools      2.0.0      2019-09-15 [1] CRAN (R 3.6.1)
 rstudioapi      0.10       2019-03-19 [1] CRAN (R 3.6.1)
 scales        * 1.1.0      2019-11-18 [1] CRAN (R 3.6.2)
 sessioninfo     1.1.1      2018-11-05 [1] CRAN (R 3.6.1)
 shiny           1.4.0      2019-10-10 [1] CRAN (R 3.6.1)
 shinyjs         1.1        2020-01-13 [1] CRAN (R 3.6.2)
 shinystan       2.5.0      2018-05-01 [1] CRAN (R 3.6.1)
 shinythemes     1.1.2      2018-11-06 [1] CRAN (R 3.6.1)
 SnowballC       0.6.0      2019-01-15 [1] CRAN (R 3.6.0)
 StanHeaders     2.21.0-1   2020-01-19 [1] CRAN (R 3.6.2)
 stringi         1.4.5      2020-01-11 [1] CRAN (R 3.6.2)
 stringr         1.4.0      2019-02-10 [1] CRAN (R 3.6.1)
 survival        3.1-8      2019-12-03 [1] CRAN (R 3.6.2)
 threejs         0.3.1      2017-08-13 [1] CRAN (R 3.6.1)
 tibble        * 2.1.3      2019-06-06 [1] CRAN (R 3.6.1)
 tictoc        * 1.0        2014-06-17 [1] CRAN (R 3.6.0)
 tidymodels    * 0.0.3      2019-10-04 [1] CRAN (R 3.6.1)
 tidyposterior   0.0.2      2018-11-15 [1] CRAN (R 3.6.1)
 tidypredict     0.4.3      2019-09-03 [1] CRAN (R 3.6.1)
 tidyr         * 1.0.0      2019-09-11 [1] CRAN (R 3.6.1)
 tidyselect      0.2.5      2018-10-11 [1] CRAN (R 3.6.1)
 tidytext        0.2.2      2019-07-29 [1] CRAN (R 3.6.1)
 timeDate        3043.102   2018-02-21 [1] CRAN (R 3.6.0)
 tokenizers      0.2.1      2018-03-29 [1] CRAN (R 3.6.1)
 vctrs           0.2.1      2019-12-17 [1] CRAN (R 3.6.2)
 withr           2.1.2      2018-03-15 [1] CRAN (R 3.6.1)
 workflows       0.1.0      2019-12-30 [1] CRAN (R 3.6.2)
 xfun            0.12       2020-01-13 [1] CRAN (R 3.6.2)
 xml2            1.2.2      2019-08-09 [1] CRAN (R 3.6.1)
 xtable          1.8-4      2019-04-21 [1] CRAN (R 3.6.1)
 xts             0.12-0     2020-01-19 [1] CRAN (R 3.6.2)
 yardstick     * 0.0.4      2019-08-26 [1] CRAN (R 3.6.1)
 zeallot         0.1.0      2018-01-28 [1] CRAN (R 3.6.1)
 zoo             1.8-7      2020-01-10 [1] CRAN (R 3.6.2)

[1] C:/Users/tbats/Documents/R/win-library/3.6
[2] C:/Program Files/R/R-3.6.2/library
```


kevinushey commented 4 years ago

Sorry, but this example is not quite minimal and so it will be difficult for us to investigate.

Assuming I'm reading your code correctly, I suspect the use of multiprocessing may be a culprit -- you likely need to explicitly import sklearn_e inside any function that is going to be run on a separate worker.
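The failure mode suggested here can be imitated in base R without any parallel backend. The sketch below is a toy under stated assumptions: `worker_env` only simulates a fresh worker session that never received the caller's globals, `alias` is a plain string standing in for the `sklearn_e` module object, and `uses_global` / `resolves_locally` are hypothetical names. It does not reproduce how future or furrr actually decide which globals to export.

```r
# A global alias standing in for `sklearn_e` (a plain string, not a Python module)
alias <- "sklearn.ensemble"

# Relies on a free variable that must be found outside the function
uses_global <- function() alias

# Resolves everything it needs inside its own body
# (stands in for calling import() inside the model function)
resolves_locally <- function() {
  alias <- "sklearn.ensemble"
  alias
}

# Simulated worker: an environment that cannot see the caller's globals
worker_env <- new.env(parent = baseenv())
environment(uses_global) <- worker_env
environment(resolves_locally) <- worker_env

# Capture the outcome of each call instead of stopping on error
res_global <- tryCatch(uses_global(), error = function(e) e)
res_local  <- tryCatch(resolves_locally(), error = function(e) e)

inherits(res_global, "error")  # the free-variable lookup fails
res_local                      # the self-contained function still works
```

The first call fails because the lookup chain from the simulated worker never reaches the environment where `alias` was defined, which mirrors the "object 'sklearn_e' not found" error in the reprex; the second call succeeds because it depends on nothing outside its body.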

ercbk commented 4 years ago

Oof. Yeah I probably could've whittled the code down a bit more. Sorry about that. What you're saying makes sense especially given the error. It's just that the future_map_dfr occurs in summarize_tune_results which is the 2nd layer, and when I started the loop from there, it worked fine without import being in sklearn_rf_FUN.