topepo / caret

caret (Classification And Regression Training) R package that contains misc functions for training and plotting classification and regression models
http://topepo.github.io/caret/index.html

caret rpart doesn't work as rpart::rpart() #1057

Closed fahadshery closed 4 years ago

fahadshery commented 5 years ago

Hi,

I successfully created an rpart model by:

inTraining2 <- createDataPartition(complaints_4_trees$COMPLAINT_TYPE_SIMPLIFIED,p = 0.8,list = FALSE,times = 1)
train2 <- complaints_4_trees[inTraining2,]
test2 <- complaints_4_trees[-inTraining2,]

down_tree_new <- downSample(train2,train2$COMPLAINT_TYPE_SIMPLIFIED)

fitTree3 <- rpart(COMPLAINT_TYPE_SIMPLIFIED ~ ., data = down_tree_new,
                 method = "class")

rpart.plot(fitTree3)

fitTree3_predicted <- predict(fitTree3, test2, type = "class")

confusionMatrix(fitTree3_predicted,test2$COMPLAINT_TYPE_SIMPLIFIED)

I want to do the same using train(), but I am running into various problems. Here is how I am trying to build the rpart model using train():

## caret was not happy with the factor levels, so I simplified them further:

down_tree_new <- down_tree_new %>% mutate(COMPLAINT_TYPE_SIMPLIFIED = fct_recode(COMPLAINT_TYPE_SIMPLIFIED,
                                                                  "C4"  =  "Non ELC/HLC/MP (C4)",
                                                                   "Exec"   = "Exec Level"
                                                                ))

test2 <- test2 %>% mutate(COMPLAINT_TYPE_SIMPLIFIED = fct_recode(COMPLAINT_TYPE_SIMPLIFIED,
                                                                  "C4"  =  "Non ELC/HLC/MP (C4)",
                                                                   "Exec"   = "Exec Level"
                                                                ))

ctrl <- trainControl(method = "repeatedcv", number = 10, repeats = 5,
                   classProbs = TRUE, summaryFunction = multiClassSummary)

caret_down_fit1 <- train(COMPLAINT_TYPE_SIMPLIFIED ~ ., data = down_tree_new,
                         method = "rpart",
                         na.action = na.pass,
                         trControl = ctrl)

fancyRpartPlot(caret_down_fit1$finalModel)

pred <- predict(caret_down_fit1$finalModel, newdata = test2)

This gives the following error:

Error in eval(predvars, data, env) : object 'TOT_CONTCT_FOR_COMPLNT_28Dbin1' not found

However, this error goes away if I do:

pred <- predict(caret_down_fit1, newdata = test2)

Then it doesn't predict all the rows in test2:

confusionMatrix(pred, test2$COMPLAINT_TYPE_SIMPLIFIED)

Gives the following error:

Error in confusionMatrix.default(pred, test2$COMPLAINT_TYPE_SIMPLIFIED) : The data contain levels not found in the data.

Here is the data (I couldn't upload the .RData file, so I saved it in .txt format):

test2.txt down_tree_new.txt

Session Info:

>sessionInfo()
R version 3.5.1 (2018-07-02)
Platform: x86_64-apple-darwin15.6.0 (64-bit)
Running under: macOS  10.14.5

Matrix products: default
BLAS: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib

locale:
[1] en_GB.UTF-8/en_GB.UTF-8/en_GB.UTF-8/C/en_GB.UTF-8/en_GB.UTF-8

attached base packages:
[1] grid      parallel  stats     graphics  grDevices utils     datasets 
[8] methods   base     

other attached packages:
 [1] partykit_1.2-5         mvtnorm_1.0-11         libcoin_1.0-4         
 [4] rpart.plot_3.0.7       rattle_5.2.0           rpart_4.1-13          
 [7] doParallel_1.0.14      iterators_1.0.10       foreach_1.4.4         
[10] caret_6.0-84           lattice_0.20-35        textclean_0.9.3       
[13] qdap_2.3.2             RColorBrewer_1.1-2     qdapTools_1.3.3       
[16] qdapRegex_0.7.2        qdapDictionaries_1.0.7 data.table_1.11.8     
[19] stringi_1.2.4          ggrepel_0.8.0          fpp2_2.3              
[22] expsmooth_2.3          fma_2.3                forecast_8.7          
[25] recipes_0.1.5          textSummary_0.1.0      scales_1.0.0          
[28] forcats_0.4.0          stringr_1.3.1          dplyr_0.8.3           
[31] purrr_0.2.5            readr_1.1.1            tidyr_0.8.1           
[34] tibble_2.1.3           ggplot2_3.0.0          tidyverse_1.2.1       
[37] h2o_3.22.1.1          

loaded via a namespace (and not attached):
  [1] readxl_1.1.0        backports_1.1.4     plyr_1.8.4         
  [4] igraph_1.2.2        lazyeval_0.2.1      splines_3.5.1      
  [7] openNLP_0.2-6       TH.data_1.0-10      digest_0.6.19      
 [10] htmltools_0.3.6     gender_0.5.2        gdata_2.18.0       
 [13] fansi_0.4.0         MLmetrics_1.1.1     magrittr_1.5       
 [16] xlsx_0.6.1          tm_0.7-6            ROCR_1.0-7         
 [19] modelr_0.1.2        gower_0.1.2         extrafont_0.17     
 [22] matrixStats_0.54.0  wordcloud_2.6       sandwich_2.5-1     
 [25] xts_0.11-1          extrafontdb_1.0     tseries_0.10-46    
 [28] strucchange_1.5-1   colorspace_1.3-2    rvest_0.3.2        
 [31] haven_1.1.2         crayon_1.3.4        RCurl_1.95-4.11    
 [34] jsonlite_1.6        zeallot_0.1.0       survival_2.42-3    
 [37] zoo_1.8-4           glue_1.3.1          gtable_0.2.0       
 [40] ipred_0.9-7         Rttf2pt1_1.3.7      quantmod_0.4-13    
 [43] Rcpp_1.0.1          plotrix_3.7-4       Cubist_0.2.2       
 [46] Formula_1.2-3       stats4_3.5.1        lava_1.6.3         
 [49] prodlim_2018.04.18  httr_1.3.1          gplots_3.0.1       
 [52] modeltools_0.2-22   ellipsis_0.1.0      pkgconfig_2.0.2    
 [55] XML_3.98-1.19       rJava_0.9-10        openNLPdata_1.5.3-4
 [58] nnet_7.3-12         venneuler_1.1-0     utf8_1.1.4         
 [61] tidyselect_0.2.5    labeling_0.3        rlang_0.4.0        
 [64] reshape2_1.4.3      munsell_0.5.0       cellranger_1.1.0   
 [67] tools_3.5.1         cli_1.1.0           party_1.3-3        
 [70] generics_0.0.2      broom_0.5.0         evaluate_0.12      
 [73] yaml_2.2.0          ModelMetrics_1.2.0  knitr_1.20         
 [76] caTools_1.17.1.1    coin_1.3-0          nlme_3.1-137       
 [79] slam_0.1-43         xml2_1.2.0          compiler_3.5.1     
 [82] rstudioapi_0.8      curl_3.3            e1071_1.7-0        
 [85] Matrix_1.2-14       urca_1.3-0          vctrs_0.1.0        
 [88] pillar_1.4.1        lmtest_0.9-37       bitops_1.0-6       
 [91] R6_2.3.0            KernSmooth_2.23-15  C50_0.1.2          
 [94] gridExtra_2.3       codetools_0.2-15    reports_0.1.4      
 [97] MASS_7.3-50         gtools_3.8.1        assertthat_0.2.1   
[100] chron_2.3-53        xlsxjars_0.6.1      rprojroot_1.3-2    
[103] withr_2.1.2         fracdiff_1.4-2      multcomp_1.4-10    
[106] hms_0.4.2           quadprog_1.5-5      timeDate_3043.102  
[109] class_7.3-14        inum_1.0-1          rmarkdown_1.10     
[112] TTR_0.23-4          NLP_0.2-0           lubridate_1.7.4    
[115] base64enc_0.1-3 

I am new to ML so apologies in advance if I am doing something stupid :)

topepo commented 5 years ago

> I am new to ML so apologies in advance if I am doing something stupid :)

No worries!

Some things that might not have been obvious that I would try:

I think that this second point is the issue. train.formula made your model using dummy variables (like the column TOT_CONTCT_FOR_COMPLNT_28Dbin1) but the data frame test2 only has a column TOT_CONTCT_FOR_COMPLNT_28D. Try using the predict function without specifying the finalModel element.
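
To make the dummy-variable point concrete, here is a minimal sketch on made-up data (the data frame `toy` and its columns `bin`, `x` and `y` are hypothetical and not taken from the attached files). It assumes caret and rpart are installed. The formula interface of train() turns the factor `bin` into a dummy column `bin1`, so the underlying rpart object only knows about `bin1`, while the train object remembers how to rebuild it from the raw column:

```r
library(caret)
library(rpart)

set.seed(1)

# Hypothetical toy data: one two-level factor predictor and one numeric predictor
toy <- data.frame(
  bin = factor(sample(c("0", "1"), 200, replace = TRUE)),
  x   = rnorm(200)
)
toy$y <- factor(ifelse(toy$bin == "1" & toy$x > 0, "yes", "no"))

train_df <- toy[1:150, ]
test_df  <- toy[151:200, ]

fit <- train(y ~ ., data = train_df,
             method = "rpart",
             trControl = trainControl(method = "cv", number = 5))

# The final rpart model was fit on the dummy-coded columns, not the raw ones
print(attr(terms(fit$finalModel), "term.labels"))

# Predicting from finalModel with the raw test data is expected to fail,
# because test_df has a column `bin` but no column `bin1`
err <- tryCatch(
  predict(fit$finalModel, newdata = test_df, type = "class"),
  error = function(e) conditionMessage(e)
)
print(err)

# Predicting from the train object applies the same dummy coding first
pred_ok <- predict(fit, newdata = test_df)
print(length(pred_ok) == nrow(test_df))
```

In this sketch predict(fit, ...) returns one prediction per test row, whereas going through fit$finalModel would require dummy-coding test_df by hand first.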