giuseppec / iml

iml: interpretable machine learning R package
https://giuseppec.github.io/iml/

Error : all(feature.class %in% names(feature.types)) is not TRUE #63

Closed curtisburkhalter closed 5 years ago

curtisburkhalter commented 5 years ago

Hello, I'm attempting to obtain SHAP values from my GBM model created in H2O, but I'm running into issues when I try to use the iml::Predictor$new() function. I'm really sorry to put this here, but I've looked everywhere for the error I'm receiving. I've double-checked that the features used to train the model are the same as those in the test set, but I keep getting the error referenced in the subject line. Could you explain why this might be happening? I know I haven't provided a reproducible example, but the data I'm using to create my model is confidential and can't be shared, so if you can provide even some idea of what this error is telling me, I would greatly appreciate it.

j-hartshorn commented 5 years ago

Can you share a version of your code that reproduces the error, but with open source data? As it stands, it is very hard to help.

christophM commented 5 years ago

Can you try with the GitHub version? Use devtools::install_github("christophm/iml")

curtisburkhalter commented 5 years ago

Hello,

I tried with the GitHub version and am still getting the same error message. Here is the code that produces the error:

# 1. Create a data frame with just the features
features_iml <- as.data.frame(df_testR) %>% dplyr::select(-returned)

# 2. Create a vector with the actual responses
response_iml <- as.numeric(as.vector(df_testR$returned))

# 3. Create a custom predict function that returns the predicted values as a
#    vector (probability of customer churn in my example)
pred <- function(model, newdata) {
  results <- as.data.frame(h2o.predict(model, as.h2o(newdata)))
  return(results[[3L]])
}

# 4. Example of prediction output
pred(GBM5, features_iml) %>% head()

# 5. Create the Predictor object
predictor = Predictor$new(model = GBM5, data = features_iml, y = response_iml, predict.fun = pred, class = "classification")

Here are also some basic descriptions of the dataset and model object I'm using in the code above:

> class(GBM5)
[1] "H2OBinomialModel"
attr(,"package")
[1] "h2o"

> class(df_testR)
[1] "tbl_df" "tbl" "data.frame"

> dim(df_testR)
[1] 47006 44

If there is anything else I can provide that might be insightful please let me know. I greatly appreciate you taking the time to look at this.

Best, Curtis


curtisburkhalter commented 5 years ago

@christophM here is a sample of anonymized data:

structure(list(dlr_id_cur = c(1, 2), date_eff = structure(c(16014, 15416), class = "Date"), new_vec_ind = structure(c(1L, 1L), .Label = c("NNA", "UNA"), class = "factor"), cntrct_term = c(9587879614862828, 19), amt_financed = c(9455359, 65561175), reg_payment = c(885288, 389371), acct_stat_cd = structure(c(3L, 3L), .Label = c("11", "22", "33"), class = "factor"), base_rental = c(1, 626266), down_pymt = c(2, 6654661), car_count = c(5, 1), dur_lease = c(3974, 6466), returned = structure(1:2, .Label = c("00", "11"), class = "factor"), state = structure(c(10L, 1L), .Label = c("ANA", "BNA", "CNA", "DNA", "FNA", "GNA", "HNA", "INA", "KNA", "LNA", "MNA", "NNA", "ONA", "PNA", "QNA", "RNA", "SNA", "TNA", "UNA", "VNA", "WNA"), class = "factor"), zip = c(34633, 45222), zip_two_digits = structure(c(71L, 36L), .Label = c("00", "01", "02", "03", "04", "05", "06", "07", "08", "09", "110", "111", "112", "113", "114", "115", "116", "117", "118", "119", "220", "221", "222", "223", "224", "225", "226", "227", "228", "229", "330", "331", "332", "333", "334", "335", "336", "337", "338", "339", "440", "441", "442", "443", "444", "445", "446", "447", "448", "449", "550", "551", "552", "553", "554", "555", "556", "557", "558", "559", "660", "661", "662", "663", "664", "665", "666", "667", "668", "669", "770", "771", "772", "773", "774", "775", "776", "777", "778", "779", "880", "881", "882", "883", "884", "885", "886", "887", "888", "889", "990", "991", "992", "993", "994", "995", "996", "997", "998", "999", "ANA", "BNA", "CNA", "ENA", "GNA", "HNA", "JNA", "KNA", "LNA", "MNA", "NNA", "PNA", "RNA", "SNA", "TNA", "VNA" ), class = "factor") , mod_year_date = c(8156, 6278), vehic_mod_fam_code = structure(c(2L, 2L), .Label = c("BNA", "CNA", "ENA", "MNA", "SNA", "TNA", "VNA", "XNA"), class = "factor"), mod_class_code = structure(c(4L, 2L ), .Label = c("BNA", "CNA", "ENA", "GNA", "MNA", "RNA", "SNA" ), class = "factor"), count_dl_DL_CDE_CSPS_A_NP = c(945, 337), DL_CDE_CSPS_A_NP_avg_dl = 
c(3355188283749626, 8835582388327814 ), count_sv_DL_CDE_CSPS_A_NP = c(6532, 8475), DL_CDE_CSPS_A_NP_avg_sv = c(4471193398278526, 6934672627789796), count_dl_NUM_CSPS_INIT_SCR = c(774, 773 ), NUM_CSPS_INIT_SCR_avg_dl = c(9468453388562312, 5847816458727333 ), count_sv_NUM_CSPS_INIT_SCR = c(2467, 3882), NUM_CSPS_INIT_SCR_avg_sv = c(5857936629789154, 8963457353776469), count_FFV = c(8563, 2566), average_FFV = c(25697792913881564, 13693335921646120), csps_NUM_SV = c(8, 6), avg_SV_rating = c(9817541424596360, 6218928542331853), csps_FFV_ratio = c(23125612473476952, 2), avg_DL_rating = c(2182256921592387, 7668957586431513), has_DL_rating = c(1, 8), has_bad_DL_rating = c(2, 4), serv_has_MNT = c(7, 3), serv_has_SCP = c(5, 4), serv_has_ELW = c(9, 4), serv_has_LCP = c(7, 1), ro_count = c(6, 1), ro_tot_cust_pay = c(2, 188759), ro_tot_pay = c(3, 764372), date_eff_weekday = structure(c(4L, 3L), .Label = c("FNA", "MNA", "SNA", "TNA", "WNA"), class = "factor"), date_eff_month_int = c(83, 7), date_eff_day = c(2, 24)), .Names = c("dlr_id_cur", "date_eff", "new_vec_ind", "cntrct_term", "amt_financed", "reg_payment", "acct_stat_cd", "base_rental", "down_pymt", "car_count", "dur_lease", "returned", "state", "zip", "zip_two_digits", "mod_year_date", "vehic_mod_fam_code", "mod_class_code", "count_dl_DL_CDE_CSPS_A_NP", "DL_CDE_CSPS_A_NP_avg_dl", "count_sv_DL_CDE_CSPS_A_NP", "DL_CDE_CSPS_A_NP_avg_sv", "count_dl_NUM_CSPS_INIT_SCR", "NUM_CSPS_INIT_SCR_avg_dl", "count_sv_NUM_CSPS_INIT_SCR", "NUM_CSPS_INIT_SCR_avg_sv", "count_FFV", "average_FFV", "csps_NUM_SV", "avg_SV_rating", "csps_FFV_ratio", "avg_DL_rating", "has_DL_rating", "has_bad_DL_rating", "serv_has_MNT", "serv_has_SCP", "serv_has_ELW", "serv_has_LCP", "ro_count", "ro_tot_cust_pay", "ro_tot_pay", "date_eff_weekday", "date_eff_month_int", "date_eff_day"), row.names = 1:2, class = "data.frame")

curtisburkhalter commented 5 years ago

I went back and looked at the source code for the package, specifically 'utils.R'. It seems that my column of class 'Date' is causing the issue, because 'Date' is not one of the accepted 'feature.types'. It looks like I should convert that column to numeric. Sorry for bringing up such a trivial issue, but I think it can be closed now.
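
For anyone landing here with the same error, the conversion described above can be sketched in base R as follows. This is a minimal illustration, not the package's own fix: the two-row data frame only mimics the anonymized sample posted earlier, and converting with as.numeric() (days since 1970-01-01) is one possible choice among several.

```r
# Two rows modelled on the anonymized sample from this thread (illustrative only)
features <- data.frame(
  date_eff  = structure(c(16014, 15416), class = "Date"),
  car_count = c(5, 1)
)

# Flag every column that inherits from class "Date"
is_date <- vapply(features, inherits, logical(1), what = "Date")

# Replace each Date column with its numeric value (days since 1970-01-01)
features[is_date] <- lapply(features[is_date], as.numeric)

# Inspect the resulting column classes
sapply(features, class)
```

Whether a raw day count is a sensible feature depends on the model; the same column type would also have to be used when training.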

christophM commented 5 years ago

Thanks for sharing the eventual problem that caused the error.

ck37 commented 5 years ago

I also ran into this issue - it would be helpful to the user to have a clearer error message. E.g. one could report which classes and specific variables are causing the error. That would really help to speed the identification and resolution of the problem.
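
Until the package reports this itself, a small check run before Predictor$new() can name the offending columns. The sketch below rests on two assumptions that may not hold for every iml version: that the accepted classes are numeric, integer, character, factor and ordered, and that only the first class of each column is checked. The helper name find_unsupported is hypothetical and not part of iml.

```r
# Classes assumed to be accepted; adjust to match your installed iml version
supported_classes <- c("numeric", "integer", "character", "factor", "ordered")

# Hypothetical helper: return the first class of every column that is not supported
find_unsupported <- function(data, supported = supported_classes) {
  first_class <- vapply(data, function(col) class(col)[1], character(1))
  first_class[!first_class %in% supported]
}

# Toy data with one Date column and one logical column
df <- data.frame(
  amount = c(10.5, 20),
  when   = as.Date(c("2013-11-05", "2012-03-17")),
  flag   = c(TRUE, FALSE)
)

bad <- find_unsupported(df)
if (length(bad) > 0) {
  # Report each offending variable together with its class
  message("Unsupported feature classes: ",
          paste0(names(bad), " (", bad, ")", collapse = ", "))
}
```

Running a check like this first turns the bare stopifnot() failure into a list of variables to convert or drop.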

DavZim commented 3 years ago

I've encountered this error when I had logical variables. The error is thrown in the function get.feature.type(), where logical is not defined. Not sure if this is intended or a missing feature. I solved it by converting the logical variables to integers first.
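
The workaround for logical variables can be sketched the same way in base R. The column names here are made up for illustration.

```r
# Toy data with one logical column (names are illustrative only)
df <- data.frame(
  has_rating = c(TRUE, FALSE, TRUE),
  ro_count   = c(6L, 1L, 3L)
)

# Find logical columns and convert them to 0/1 integers
is_lgl <- vapply(df, is.logical, logical(1))
df[is_lgl] <- lapply(df[is_lgl], as.integer)

# Inspect the resulting column classes
sapply(df, class)
```

Converting to a two-level factor instead would make the variable categorical rather than numerical, which may suit some interpretation methods better.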

walinchus commented 2 years ago

This was helpful. I encountered the same error, but it worked once I dropped the geometry column.