topepo / caret

caret (Classification And Regression Training) is an R package that contains miscellaneous functions for training and plotting classification and regression models
http://topepo.github.io/caret/index.html

Subset must be logical error #1272

Open arios92 opened 2 years ago

arios92 commented 2 years ago

Hi everyone,

I was trying to use the heart disease dataset from Kaggle to train an rfRules model with caret. I added an index vector that specifies which instances should be used as the training dataset. However, an error popped up: "Error in subset.default(x, subset) : 'subset' must be logical".
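For reference, the message itself is raised by base R's `subset()` whenever its `subset` argument is not a logical vector. The following is a minimal sketch that reproduces the same message outside caret; the vector `x` is a made-up example, not the heart failure data, and where exactly caret reaches this call is not established here.

```r
# Hypothetical toy vector, not the heart failure data
x <- c(10, 20, 30, 40)

# A logical mask is accepted by subset()
kept <- subset(x, c(TRUE, FALSE, TRUE, FALSE))

# Numeric positions (or a data frame) are not a logical mask, so this call fails
msg <- tryCatch(
  subset(x, c(1, 3)),
  error = function(e) conditionMessage(e)
)

print(kept)
print(msg)
```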

Kindly assist in rectifying the issue. I would truly appreciate your help.

Please find the script below:

Library Import

library(stats)
library(rapportools)

library(caret)

library(imbalance)

library(MLmetrics)

library(ggplot2)
library(reshape2)

Data Loading & Data Exploration

heartfailure_df<-read.csv("heart_failure_clinical_records_dataset.csv",header=TRUE)

head(heartfailure_df,5)
summary(heartfailure_df)
str(heartfailure_df)

oversampling

In this section, we check whether there is a class imbalance problem in DEATH_EVENT.

with(heartfailure_df,table(DEATH_EVENT))

The table above shows that close to 1/3 of the instances died of heart disease, so we need a method to balance the class sizes.
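The imbalance check can also be expressed as proportions rather than raw counts. This is a small sketch on a made-up outcome vector (`death_event` is hypothetical and only chosen so that one third of the values are 1); with the real data the same calls would be applied to `heartfailure_df$DEATH_EVENT`.

```r
# Hypothetical outcome vector, chosen so that 3 of 9 instances are deaths
death_event <- c(0, 0, 1, 0, 1, 0, 0, 1, 0)

# Class counts and the share of each class
counts <- table(death_event)
props <- prop.table(counts)

# Ratio of minority to majority class size (1 would mean perfectly balanced)
imbalance_ratio <- min(counts) / max(counts)

print(counts)
print(round(props, 3))
print(imbalance_ratio)
```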

set.seed(48)
new_hf<-racog(heartfailure_df,numInstances=100,classAttr='DEATH_EVENT')
heartfailure_df_over<-rbind(heartfailure_df,new_hf)
summary(new_hf)
for(i in c(2,4,6,10,11,13)){
  heartfailure_df[,i]<-as.factor(heartfailure_df[,i])
  heartfailure_df_over[,i]<-as.factor(heartfailure_df_over[,i])
}
plotComparison(heartfailure_df, heartfailure_df_over, attrs = names(heartfailure_df)[1:3],classAttr='DEATH_EVENT')

logistic regression


set.seed(48)
train_hd_index<-createDataPartition(heartfailure_df_over$DEATH_EVENT,p=0.75,list=FALSE,times=1)

train_hd_over<-heartfailure_df_over[train_hd_index,]
test_hd_over<-heartfailure_df_over[-train_hd_index,]
with(train_hd_over,addmargins(table(train_hd_over$DEATH_EVENT)))
with(test_hd_over,addmargins(table(test_hd_over$DEATH_EVENT)))

Random Forest

library(inTrees)

rfTrain<-trainControl(method="repeatedcv",number=3,repeats=3)
rfGrid<-expand.grid(mtry=c(2:5),maxdepth=c(3:5))

rfFit<-train(x=heartfailure_df_over[,c(1:12)],y=heartfailure_df_over[,c(13)],data=heartfailure_df_over,method="rfRules",tuneGrid=rfGrid,trControl=rfTrain,subset=train_hd_over)
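One possible way around the error, offered as a sketch under assumptions rather than a fix confirmed by the maintainers: when calling `train()` with `x` and `y`, select the training rows before the call instead of passing `subset=` (note that the call above passes the data frame `train_hd_over`, not an index vector or logical mask). The toy data frame `toy_df`, the base-R stratified split, and the helper `fit_rf_rules()` below are hypothetical stand-ins; the grid values are only illustrative.

```r
# Hypothetical toy data standing in for heartfailure_df_over
set.seed(48)
toy_df <- data.frame(
  age = rnorm(40, mean = 60, sd = 10),
  ejection_fraction = rnorm(40, mean = 38, sd = 12),
  serum_creatinine = rexp(40),
  DEATH_EVENT = factor(rep(c("0", "1"), each = 20))
)

# Base-R stand-in for createDataPartition(): sample 75% of rows within each class
train_index <- sort(unlist(lapply(
  split(seq_len(nrow(toy_df)), toy_df$DEATH_EVENT),
  function(rows) sample(rows, size = round(0.75 * length(rows)))
), use.names = FALSE))

train_df <- toy_df[train_index, ]
test_df  <- toy_df[-train_index, ]

# Hypothetical helper: fit on the training rows only, without `subset=` or `data=`
fit_rf_rules <- function(train_df, outcome = "DEATH_EVENT") {
  predictors <- setdiff(names(train_df), outcome)
  caret::train(
    x = train_df[, predictors, drop = FALSE],
    y = train_df[[outcome]],
    method = "rfRules",
    tuneGrid = expand.grid(mtry = 2:3, maxdepth = 3:5),
    trControl = caret::trainControl(method = "repeatedcv", number = 3, repeats = 3)
  )
}

# Usage (requires caret, randomForest and inTrees to be installed):
# rfFit <- fit_rf_rules(train_df)
```

With the objects from the script above, the equivalent call would use `x=train_hd_over[,c(1:12)]` and `y=train_hd_over[,c(13)]`, keeping `test_hd_over` aside for evaluation.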
sessionInfo()

R version 4.1.0 (2021-05-18)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19044)

Matrix products: default

locale:
[1] LC_COLLATE=English_United Kingdom.1252 LC_CTYPE=English_United Kingdom.1252
[3] LC_MONETARY=English_United Kingdom.1252 LC_NUMERIC=C
[5] LC_TIME=English_United Kingdom.1252

attached base packages:
[1] stats graphics grDevices utils datasets methods base

other attached packages:
[1] inTrees_1.2 reshape2_1.4.4 glmnet_4.1-3 boot_1.3-28 MLmetrics_1.1.1
[6] smotefamily_1.3.1 LiblineaR_2.10-12 imbalance_1.0.2.1 lme4_1.1-27.1 Matrix_1.3-4
[11] caret_6.0-90 lattice_0.20-44 ggplot2_3.3.5 MASS_7.3-54 rapportools_1.0
[16] reshape_0.8.8

loaded via a namespace (and not attached):
[1] treemapify_2.5.5 nlme_3.1-152 lubridate_1.8.0 tools_4.1.0
[5] utf8_1.2.1 R6_2.5.1 DT_0.20 rpart_4.1-15
[9] arules_1.7-3 DBI_1.1.2 colorspace_2.0-1 nnet_7.3-16
[13] gbm_2.1.8 withr_2.4.3 tidyselect_1.1.1 compiler_4.1.0
[17] cli_3.1.0 flashClust_1.01-2 labeling_0.4.2 scales_1.1.1
[21] randomForest_4.6-14 proxy_0.4-26 stringr_1.4.0 digest_0.6.27
[25] minqa_1.2.4 rmarkdown_2.11 pkgconfig_2.0.3 htmltools_0.5.1.1
[29] parallelly_1.30.0 FactoMineR_2.4 htmlwidgets_1.5.4 rlang_0.4.11
[33] rstudioapi_0.13 shape_1.4.6 generics_0.1.1 farver_2.1.0
[37] jsonlite_1.7.2 zoo_1.8-9 dplyr_1.0.6 ModelMetrics_1.2.2.2
[41] magrittr_2.0.1 leaps_3.1 Rcpp_1.0.6 munsell_0.5.0
[45] fansi_0.5.0 ggfittext_0.9.1 lifecycle_1.0.1 bnlearn_4.7
[49] scatterplot3d_0.3-41 stringi_1.6.2 pROC_1.18.0 yaml_2.2.1
[53] plyr_1.8.6 recipes_0.1.17 grid_4.1.0 parallel_4.1.0
[57] listenv_0.8.0 ggrepel_0.9.1 crayon_1.4.2 splines_4.1.0
[61] pander_0.6.4 RRF_1.9.1 knitr_1.37 pillar_1.6.4
[65] xgboost_1.5.0.2 future.apply_1.8.1 codetools_0.2-18 stats4_4.1.0
[69] glue_1.4.2 evaluate_0.14 data.table_1.14.2 vctrs_0.3.8
[73] nloptr_1.2.2.3 foreach_1.5.1 gtable_0.3.0 purrr_0.3.4
[77] future_1.23.0 assertthat_0.2.1 xfun_0.29 gower_0.2.2
[81] prodlim_2019.11.13 xtable_1.8-4 e1071_1.7-9 class_7.3-19
[85] survival_3.2-11 timeDate_3043.102 tibble_3.1.2 iterators_1.0.13
[89] cluster_2.1.2 lava_1.6.10 globals_0.14.0 ellipsis_0.3.2
[93] ipred_0.9-12