bethatkinson / rpart

Recursive Partitioning and Regression Trees

Error 'data' must be a data.frame, environment, or list #30

Open thamron opened 3 years ago

thamron commented 3 years ago

Converting an rpart object to class party for plotting fails. The code is shown below:

```r
library(tidyverse)
library(mdsr)

url <- "http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
census <- read_csv(
  url,
  col_names = c(
    "age", "workclass", "fnlwgt", "education",
    "education_1", "marital_status", "occupation", "relationship",
    "race", "sex", "capital_gain", "capital_loss",
    "hours_per_week", "native_country", "income"
  )
) %>%
  mutate(income = factor(income))

library(tidymodels)
set.seed(364)
n <- nrow(census)
census_parts <- census %>%
  initial_split(prop = 0.8)
train <- census_parts %>% training()
test <- census_parts %>% testing()
pi_bar <- train %>%
  count(income) %>%
  mutate(pct = n / sum(n)) %>%
  filter(income == ">50K") %>%
  pull(pct)
pi_bar

form <- as.formula(
  "income ~ age + workclass + education + marital_status +
  occupation + relationship + race + sex +
  capital_gain + capital_loss + hours_per_week"
)

mod_tree <- decision_tree(mode = "classification") %>%
  set_engine("rpart") %>%
  fit(form, data = train)
mod_tree

library(rpart)
library(partykit)
plot(as.party(mod_tree$fit))
```

bethatkinson commented 3 years ago

The class of mod_tree is not "rpart", so there is no guarantee that it will work with downstream packages. The problem is caused by the tidymodels package, not by rpart itself.

```r
class(mod_tree)
#> [1] "_rpart" "model_fit"

fit <- rpart(form, data = train)

plot(as.party(fit))  ## This works
```
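The thread does not spell out why the conversion fails, so the following is an assumption rather than a confirmed diagnosis. `as.party()` calls `model.frame()` on the rpart object; when the object carries no stored model frame, rpart rebuilds it by re-evaluating the original call. A wrapper such as parsnip records a call whose `data` argument is a name that only existed inside the wrapper, so replaying that call later can pick up something else (possibly the function `utils::data()`), which would match the reported error. rpart's `model = TRUE` argument keeps the model frame inside the fitted object, so nothing has to be replayed. This minimal sketch uses rpart's built-in `kyphosis` data instead of the census data, and `fit_via_wrapper()` is a hypothetical helper that only imitates a wrapper.

```r
library(rpart)
library(partykit)

# Hypothetical helper imitating a wrapper (e.g. a parsnip engine) that calls
# rpart() with a local variable named `data`. That name exists only inside
# this function, so the call stored in the fitted object cannot be replayed
# elsewhere; model = TRUE stores the model frame so it never needs to be.
fit_via_wrapper <- function(df) {
  data <- df
  rpart(Kyphosis ~ Age + Number + Start, data = data, model = TRUE)
}

fit <- fit_via_wrapper(kyphosis)
class(fit)
#> [1] "rpart"

# model.frame() returns the stored frame, so as.party() does not have to
# re-evaluate the original call.
party_tree <- as.party(fit)
plot(party_tree)
```

For the pipeline in the issue, passing the same argument through the engine, `set_engine("rpart", model = TRUE)`, and then calling `as.party(mod_tree$fit)` should behave the same way, since `set_engine()` forwards extra engine arguments to `rpart()`; this has not been tested in the thread. Fitting `rpart()` directly, as suggested above, avoids the wrapper altogether.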


From: thamron. Sent: Tuesday, August 24, 2021 9:29 PM. To: bethatkinson/rpart. Subject: [bethatkinson/rpart] Error 'data' must be a data.frame, environment, or list (#30)
