ModelOriented / DALEX

moDel Agnostic Language for Exploration and eXplanation
https://dalex.drwhy.ai
GNU General Public License v3.0

R: H2O Random Forest Classification Difficulties #226

Closed: stefan-grujic closed this issue 4 years ago

stefan-grujic commented 4 years ago

Hi,

I'm interested in using DALEX in a university project on interpretable ML in the context of microbial genomics. I prepared a taster dataset in order to test out a very basic pipeline using your package, but I seem to be having some issues explaining a random forest classification model built with H2O.

Here is the output of explain_h2o() from DALEXtra:

>  explained = DALEXtra::explain_h2o(model = model,
+                                   data = h2o_X,
+                                   y = h2o_Y,
+                                   label = "h2o_rfe",
+                                   type = 'classification')

Preparation of a new explainer is initiated
  -> model label       :  h2o_rfe 
  -> data              :  10  rows  145488  cols 
  -> target variable   :  10  values 
  -> model_info        :  package h2o , ver. 3.30.0.2 , task classification (  default  ) 
  -> model_info        :  type set to  classification 
  -> predict function  :  yhat.H2OBinomialModel  will be used (  default  )
  |===================================| 100%
  -> predicted values  :  numerical, min =  0.2033333 , mean =  0.5383333 , max =  0.8566667  
  -> residual function :  difference between y and yhat (  default  )
  |===================================| 100%
  -> residuals         :  numerical, min =  NA , mean =  NA , max =  NA  
  A new explainer has been created!  
Warning message:
In DALEXtra::explain_h2o(model = model, data = h2o_X, y = h2o_Y,  :
  NAs introduced by coercion 

After trying a lot of different suggestions for the predict function, my guess is that the problem comes from the residual function, since that appears to be where the NAs are being introduced.

Here is the output of model_performance():

> model_performance(explained)
Measures for:  classification
recall   : NA 
precision: NA 
f1       : NA 
accuracy : NA 
auc      : NA

Residuals:
Error in quantile.default(x$residuals$diff, seq(0, 1, 0.1)) : 
  missing values and NaN's not allowed if 'na.rm' is FALSE
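A minimal base R sketch of how NA residuals can arise, assuming the target reaches the residual step as text labels (the labels and probabilities below are hypothetical, not the real data). The default residual function takes y minus yhat, so if y cannot be coerced to numbers every residual is NA, and quantile() then stops with the same message as above.

```r
# Hypothetical host labels and predicted probabilities (not the real data)
y_labels <- c("Human", "Poultry", "Human", "Poultry")
yhat <- c(0.2, 0.8, 0.3, 0.7)

# Coercing text labels to numbers yields NA ("NAs introduced by coercion")
y_num <- suppressWarnings(as.numeric(y_labels))

# Default-style residuals: difference between y and yhat, so every value is NA
resid_default <- y_num - yhat
print(resid_default)

# quantile() rejects NA input unless na.rm = TRUE, mirroring the error above
quantile_msg <- tryCatch(
  quantile(resid_default, seq(0, 1, 0.1)),
  error = function(e) conditionMessage(e)
)
print(quantile_msg)
```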

Additionally, here is the output of variable_importance():

> variable_importance(explained)

ERROR: Unexpected HTTP Status code: 400 Bad Request (url = http://localhost:54321/99/Rapids)

java.lang.IllegalArgumentException
 [1] "java.lang.IllegalArgumentException: H2O does not currently reorder rows, please sort your row selection first"
 [2] "    water.rapids.ast.prims.mungers.AstRowSlice.apply(AstRowSlice.java:46)"                                    
 [3] "    water.rapids.ast.prims.mungers.AstRowSlice.apply(AstRowSlice.java:20)"                                    
 [4] "    water.rapids.ast.AstExec.exec(AstExec.java:63)"                                                           
 [5] "    water.rapids.ast.prims.assign.AstTmpAssign.apply(AstTmpAssign.java:48)"                                   
 [6] "    water.rapids.ast.prims.assign.AstTmpAssign.apply(AstTmpAssign.java:17)"                                   
 [7] "    water.rapids.ast.AstExec.exec(AstExec.java:63)"                                                           
 [8] "    water.rapids.Session.exec(Session.java:85)"                                                               
 [9] "    water.rapids.Rapids.exec(Rapids.java:94)"                                                                 
[10] "    water.api.RapidsHandler.exec(RapidsHandler.java:38)"                                                      
[11] "    jdk.internal.reflect.GeneratedMethodAccessor8.invoke(Unknown Source)"                                     
[12] "    java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)" 
[13] "    java.base/java.lang.reflect.Method.invoke(Method.java:566)"                                               
[14] "    water.api.Handler.handle(Handler.java:60)"                                                                
[15] "    water.api.RequestServer.serve(RequestServer.java:470)"                                                    
[16] "    water.api.RequestServer.doGeneric(RequestServer.java:301)"                                                
[17] "    water.api.RequestServer.doPost(RequestServer.java:227)"                                                   
[18] "    javax.servlet.http.HttpServlet.service(HttpServlet.java:755)"                                             
[19] "    javax.servlet.http.HttpServlet.service(HttpServlet.java:848)"                                             
[20] "    org.eclipse.jetty.servlet.ServletHolder.handle(ServletHolder.java:684)"                                   
[21] "    org.eclipse.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:501)"                               
[22] "    org.eclipse.jetty.server.handler.ContextHandler.doHandle(ContextHandler.java:1086)"                       
[23] "    org.eclipse.jetty.servlet.ServletHandler.doScope(ServletHandler.java:427)"                                
[24] "    org.eclipse.jetty.server.handler.ContextHandler.doScope(ContextHandler.java:1020)"                        
[25] "    org.eclipse.jetty.server.handler.ScopedHandler.handle(ScopedHandler.java:135)"                            
[26] "    org.eclipse.jetty.server.handler.HandlerCollection.handle(HandlerCollection.java:154)"                    
[27] "    org.eclipse.jetty.server.handler.HandlerWrapper.handle(HandlerWrapper.java:116)"                          
[28] "    water.webserver.jetty8.Jetty8ServerAdapter$LoginHandler.handle(Jetty8ServerAdapter.java:119)"             
[29] "    org.eclipse.jetty.server.handler.HandlerCollection.handle(HandlerCollection.java:154)"                    
[30] "    org.eclipse.jetty.server.handler.HandlerWrapper.handle(HandlerWrapper.java:116)"                          
[31] "    org.eclipse.jetty.server.Server.handle(Server.java:370)"                                                  
[32] "    org.eclipse.jetty.server.AbstractHttpConnection.handleRequest(AbstractHttpConnection.java:494)"           
[33] "    org.eclipse.jetty.server.BlockingHttpConnection.handleRequest(BlockingHttpConnection.java:53)"            
[34] "    org.eclipse.jetty.server.AbstractHttpConnection.content(AbstractHttpConnection.java:984)"                 
[35] "    org.eclipse.jetty.server.AbstractHttpConnection$RequestHandler.content(AbstractHttpConnection.java:1045)" 
[36] "    org.eclipse.jetty.http.HttpParser.parseNext(HttpParser.java:861)"                                         
[37] "    org.eclipse.jetty.http.HttpParser.parseAvailable(HttpParser.java:236)"                                    
[38] "    org.eclipse.jetty.server.BlockingHttpConnection.handle(BlockingHttpConnection.java:72)"                   
[39] "    org.eclipse.jetty.server.bio.SocketConnector$ConnectorEndPoint.run(SocketConnector.java:264)"             
[40] "    org.eclipse.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:608)"                         
[41] "    org.eclipse.jetty.util.thread.QueuedThreadPool$3.run(QueuedThreadPool.java:543)"                          
[42] "    java.base/java.lang.Thread.run(Thread.java:834)"                                                          

Error in .h2o.doSafeREST(h2oRestApiVersion = h2oRestApiVersion, urlSuffix = page,  : 

ERROR MESSAGE:

H2O does not currently reorder rows, please sort your row selection first
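Permutation-based variable importance works by shuffling the values of one variable at a time across rows. A plausible reading of this error, not confirmed in the thread, is that the shuffled (and therefore unsorted) row indices were applied to the H2OFrame passed as data, which H2O refuses. The base R sketch below performs one such shuffle on a plain data.frame, where reordering rows is unproblematic; the column names are hypothetical. Passing an R data.frame instead of h2o_X as data is an assumption worth testing, not a confirmed fix.

```r
# Hypothetical binary gene-presence features (not the real data)
X <- data.frame(gene_a = c(1, 0, 1, 0, 1),
                gene_b = c(0, 0, 1, 1, 0))

set.seed(1)
idx <- sample(nrow(X))  # random permutation of row indices, generally unsorted
print(is.unsorted(idx))

# Shuffle a single column in R memory; the other column keeps its order
X_perm <- X
X_perm$gene_a <- X$gene_a[idx]
print(X_perm)
```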

Here is the full script I wrote in case it would be useful in identifying the issue:

library(pacman)
p_load('DALEX', 'DALEXtra', 'h2o', 'tidyverse')

h2o.init()

data = read.csv('/media/newdrive/mod_7_data/model_input/gpa_vcf.csv')
X = data[, -which(names(data) == 'SPE')]
Y = data$SPE

h2odata = as.h2o(data)
h2o_X = as.h2o(X)
h2o_Y = as.h2o(Y)

model = h2o.randomForest(training_frame = h2odata, y = 'SPE', model_id = "RF_1")

explained = DALEXtra::explain_h2o(model = model,
                                  data = h2o_X,
                                  y = h2o_Y,
                                  label = "h2o_rfe",
                                  type = 'classification')

model_performance(explained)
variable_importance(explained)

I apologise if this is a trivial issue, and look forward to your response.

Thanks, Stefan Grujic

maksymiuks commented 4 years ago

Hi, thank you for the issue. I believe the problem may be related to the fact that DALEXtra is not yet adjusted to DALEX 1.2.2, so multiclass classification does not work as intended. A developer version of the next release is available on the DALEXtra_0.2.3 branch, which can be installed using remotes::install_github("ModelOriented/DALEXtra", "DALEXtra_0.2.3"). Keep in mind that it requires the newest version of DALEX: remotes::install_github("ModelOriented/DALEX"). After that, everything should work (I cannot check, since I don't have access to your data). Just set the type to multiclass or let DALEX set it on its own.

Feel free to ask in case of any problems.

stefan-grujic commented 4 years ago

Hi,

I've made the changes suggested and I'm getting the following:

> explained = explain_h2o(model = model,
+                     data = h2o_X,
+                     y = h2o_Y,
+                     label = "h2o rfe",
+                     type = 'multiclass')

Preparation of a new explainer is initiated
  -> model label       :  h2o rfe 
  -> data              :  10  rows  145488  cols 
  -> target variable   :  10  values 
  -> target variable   :  Please note that 'y' is a factor.  (  WARNING  )
  -> target variable   :  Consider changing the 'y' to a logical or numerical vector.
  -> target variable   :  Otherwise I will not be able to calculate residuals or loss function.
  -> predict function  :  yhat.H2OBinomialModel  will be used (  default  )
  |=================================| 100%
  -> predicted values  :  numerical, min =  0.1 , mean =  0.4233333 , max =  0.7233333  
  |=================================| 100%
  -> model_info        :  package h2o , ver. 3.30.0.2 , task classification (  default  ) 
  -> model_info        :  type set to  multiclass 
  -> residual function :  difference between 1 and probability of true class (  default  )
  |=================================| 100%
  -> residuals         :  the residual_function returns an error when executed (  WARNING  ) 
  A new explainer has been created!  
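The residual function listed above, "difference between 1 and probability of true class", needs one probability per class for every observation. A minimal sketch with a hypothetical two-class probability matrix (values made up) shows that computation. The yhat.H2OBinomialModel predict function above returns a single numeric vector instead, which is plausibly why the residual function fails when this binomial model is forced to type 'multiclass'.

```r
# Hypothetical per-class probabilities: one row per observation, one column per class
probs <- matrix(c(0.7, 0.3,
                  0.2, 0.8,
                  0.6, 0.4),
                ncol = 2, byrow = TRUE,
                dimnames = list(NULL, c("Human", "Poultry")))

# True classes as a factor whose levels follow the column order
y <- factor(c("Human", "Poultry", "Poultry"), levels = colnames(probs))

# For each row, pick the probability assigned to the true class
p_true <- probs[cbind(seq_along(y), as.integer(y))]

# Multiclass residual: 1 minus the probability of the true class
resid_multiclass <- 1 - p_true
print(resid_multiclass)
```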

Do you suppose one-hot encoding the target variable will resolve the issue?

maksymiuks commented 4 years ago

Could you provide your sessionInfo()? Also, what is the result of class(model)? One-hot encoding the target variable will not solve the problem. Am I correct that the task is multiclass classification?

By the way, are you sure the data has the proper format? 145488 features and only 10 observations looks unusual.

hbaniecki commented 4 years ago

I see

  -> target variable   :  Please note that 'y' is a factor.  (  WARNING  )
  -> target variable   :  Consider changing the 'y' to a logical or numerical vector.

I guess that this warning should be changed.

stefan-grujic commented 4 years ago

The problem is currently binomial, but the same error appears when type is left at its default, so I also tried multinomial before posting. Here are the results with the default type, for comparison:

> explained = explain_h2o(model = model,
+                     data = h2o_X,
+                     y = h2o_Y,
+                     label = "h2o rfe")
Preparation of a new explainer is initiated
  -> model label       :  h2o rfe 
  -> data              :  10  rows  145488  cols 
  -> target variable   :  10  values 
  -> target variable   :  Please note that 'y' is a factor.  (  WARNING  )
  -> target variable   :  Consider changing the 'y' to a logical or numerical vector.
  -> target variable   :  Otherwise I will not be able to calculate residuals or loss function.
  -> predict function  :  yhat.H2OBinomialModel  will be used (  default  )
  |=================================| 100%
  -> predicted values  :  numerical, min =  0.1 , mean =  0.4233333 , max =  0.7233333  
  |=================================| 100%
  -> model_info        :  package h2o , ver. 3.30.0.2 , task classification (  default  ) 
  -> residual function :  difference between y and yhat (  default  )
  |=================================| 100%
  -> residuals         :  the residual_function returns an error when executed (  WARNING  ) 
  A new explainer has been created

To give you some brief insight into the data, for clarity's sake: I am taking binary expression data for certain genes and mutations in E. coli strains isolated from 5 human samples and 5 poultry samples, and using it to build a predictive model that can distinguish the host sites. As far as I can tell the data is correctly formatted; I'm just using a small sample size to test the pipeline I'm creating, and the task will become multinomial classification in the real run. H2O's own functions such as h2o.predict() and h2o.varimp() seem to run as expected, so I have no reason to suspect that the model is faulty.

Here's my sessionInfo():

R version 3.6.3 (2020-02-29)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 18.04.4 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.7.1
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.7.1

locale:
 [1] LC_CTYPE=en_GB.UTF-8      
 [2] LC_NUMERIC=C              
 [3] LC_TIME=en_GB.UTF-8       
 [4] LC_COLLATE=en_GB.UTF-8    
 [5] LC_MONETARY=en_GB.UTF-8   
 [6] LC_MESSAGES=en_GB.UTF-8   
 [7] LC_PAPER=en_GB.UTF-8      
 [8] LC_NAME=C                 
 [9] LC_ADDRESS=C              
[10] LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_GB.UTF-8
[12] LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils    
[5] datasets  methods   base     

other attached packages:
[1] DALEXtra_0.2.2 DALEX_1.2.2   
[3] h2o_3.30.0.2   pacman_0.5.1  

loaded via a namespace (and not attached):
 [1] reticulate_1.15   shape_1.4.4      
 [3] tidyselect_1.0.0  remotes_2.1.1    
 [5] purrr_0.3.3       lattice_0.20-41  
 [7] colorspace_1.4-1  vctrs_0.2.4      
 [9] rlang_0.4.6       pkgbuild_1.0.8   
[11] pillar_1.4.4      glue_1.4.0       
[13] withr_2.2.0       rappdirs_0.3.1   
[15] foreach_1.5.0     lifecycle_0.2.0  
[17] munsell_0.5.0     gtable_0.3.0     
[19] codetools_0.2-16  callr_3.4.3      
[21] ps_1.3.3          curl_4.3         
[23] fansi_0.4.1       Rcpp_1.0.4.6     
[25] scales_1.1.0      backports_1.1.6  
[27] jsonlite_1.6.1    ingredients_1.2.0
[29] gridExtra_2.3     ggplot2_3.3.0    
[31] processx_3.4.2    dplyr_0.8.4      
[33] grid_3.6.3        rprojroot_1.3-2  
[35] cli_2.0.2         tools_3.6.3      
[37] bitops_1.0-6      magrittr_1.5     
[39] RCurl_1.98-1.2    glmnet_3.0-2     
[41] tibble_3.0.1      ggdendro_0.1-20  
[43] crayon_1.3.4      pkgconfig_2.0.3  
[45] ellipsis_0.3.0    MASS_7.3-51.6    
[47] Matrix_1.2-18     data.table_1.12.8
[49] prettyunits_1.1.1 assertthat_0.2.1 
[51] rstudioapi_0.11   iterators_1.0.12 
[53] R6_2.4.1          compiler_3.6.3 

I noticed that sessionInfo() reports my DALEXtra version as 0.2.2, even though I tried to force the installation:

remotes::install_github("ModelOriented/DALEXtra", "DALEXtra_0.2.3", force=TRUE)
remotes::install_github("ModelOriented/DALEX")

The class is indeed binomial:

> class(model)
[1] "H2OBinomialModel"
attr(,"package")
[1] "h2o"

I'll try removing DALEX and DALEXtra, installing them fresh, and running everything again in case that resolves the version issue.
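A small sketch for checking which versions R will actually use, assuming the minimum versions implied in this thread (DALEX 1.2.2 and DALEXtra 0.2.3, the latter taken from the branch name). A package that is already loaded keeps running its old code until R is restarted, so restarting R after reinstalling is a sensible step before re-running the explainer. The helper installed_ok() is hypothetical, not part of DALEX.

```r
# Minimum versions assumed from this thread
required <- c(DALEX = "1.2.2", DALEXtra = "0.2.3")

# Hypothetical helper: TRUE if the package is installed at or above min_version
installed_ok <- function(pkg, min_version) {
  nzchar(system.file(package = pkg)) &&
    utils::packageVersion(pkg) >= min_version
}

# One TRUE/FALSE per package, named by package
version_check <- vapply(names(required),
                        function(p) installed_ok(p, required[[p]]),
                        logical(1))
print(version_check)
```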

maksymiuks commented 4 years ago

OK, I've inspected everything, and I have to apologize because I misled you. For some reason I thought it was a multinomial model, although it clearly says it is binomial.

In that case, type should indeed be classification, but y should be a numeric vector of 0s and 1s, where 1 denotes the positive class. A factor y is required when the task is multiclass.
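A minimal sketch of that conversion, assuming a two-level factor target like SPE (the labels below are hypothetical). Which host counts as the positive class is also an assumption here and should match the class whose probability the predict function returns.

```r
# Hypothetical two-level target standing in for data$SPE
y_factor <- factor(c("Human", "Poultry", "Poultry", "Human"))

# Assumption: "Poultry" is the positive class, coded as 1
positive_class <- "Poultry"
y01 <- as.numeric(y_factor == positive_class)
print(y01)  # 0 1 1 0

# Usage (needs a running H2O cluster and the model from the script, so commented out):
# explained <- DALEXtra::explain_h2o(model = model, data = h2o_X, y = y01,
#                                    label = "h2o_rfe", type = "classification")
```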

As @hbaniecki says, the warning should be reworked in the future.