marjoleinF / pre

an R package for deriving Prediction Rule Ensembles

speed up and bug fixes #5

Closed boennecd closed 7 years ago

boennecd commented 7 years ago

The focus here is on speed-ups, but there are also some bug fixes.

All the results below are from running:

> summary(microbenchmark::microbenchmark(
+   pre = airq.ens <- pre(Ozone ~ ., data=airquality[complete.cases(airquality),]), 
+   times = 100))

The initial result was:

  expr      min       lq    mean   median       uq      max neval
1  pre 4.018802 4.171205 4.33873 4.295319 4.407155 6.901721   100

After commit 6f298c8:

  expr     min       lq    mean   median       uq      max neval
1  pre 3.57509 3.670837 3.86512 3.793877 3.998536 4.700677   100

After commit d237fb1:

  expr      min       lq     mean  median       uq      max neval
1  pre 2.312484 2.355866 2.455929 2.40926 2.498872 3.090584   100

After commit 7b008b0:

  expr      min      lq     mean  median       uq      max neval
1  pre 1.804114 1.90265 2.039485 1.97219 2.083083 3.855207   100

The present result is:

  expr      min       lq     mean   median       uq      max neval
1  pre 1.376896 1.448779 1.524643 1.487653 1.582981 2.037187   100

Further reductions in computation time can be achieved by looking at the various calls to functions in the partykit package. These functions perform a lot of validation, which makes them slow. The biggest computational cost, though, is in training the trees. The partykit package is pure R code, so I suspect that a compiled version of CTREE, like the one in the party package, could reduce computation time further. The part for classification with learning rates has not been optimized, though a step similar to the one applied to partykit::ctree could be made there.

I don't understand the negation here:

complements <- !(names(rulevars) %in% removed_complement_rules)
rulevars <- rulevars[,!complements]
rules <- rules[!complements]

I removed the ! from !(....) and updated the tests. This reduced the CV error in the example you sent from:

mean cv error (se) = 339.6366 (92.21531) 

... to:

mean cv error (se) = 299.798 (74.28342) 
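
A toy illustration of why the negation matters may help. The data, rule names, and rule descriptions below are made up for the example and are not taken from the package; only the two indexing patterns mirror the snippet above.

```r
# Hypothetical stand-ins for the objects in the snippet above.
rulevars <- data.frame(rule1 = c(1, 0, 1),
                       rule2 = c(0, 1, 0),   # complement of rule1
                       rule3 = c(1, 1, 0))
rules <- c("x > 1", "x <= 1", "z > 2")
removed_complement_rules <- "rule2"

# With the negation, the flag is TRUE for the rules that should be KEPT,
# so indexing with !flag retains only the rule that was meant to be removed.
complements_old <- !(names(rulevars) %in% removed_complement_rules)
kept_old <- names(rulevars)[!complements_old]

# Without the negation, the flag marks the rules to drop.
complements_new <- names(rulevars) %in% removed_complement_rules
kept_new <- names(rulevars[, !complements_new, drop = FALSE])
rules_new <- rules[!complements_new]
```

In this toy case the negated version keeps only rule2, while the version without the negation keeps rule1 and rule3.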

All the tests are run against the updated version. The current implementation of the removal of complements is quite slow when there are a lot of rules. I updated it to a new version, which is still slow, so I have implemented an alternative. You can get it by calling pre with use_suggestion = TRUE. If you do not like it, then remove the following (in any case, remove the comment!):

# TODO: remove this if you do not want the suggestion I propose
if(use_suggestion && removeduplicates && removecomplements){
...

To illustrate that the new approach is faster and gives the same result, consider these calls:

> data(PimaIndiansDiabetes, package = "mlbench")
> nrow(PimaIndiansDiabetes) # a bit larger
[1] 768
> set.seed(seed <- 9602165)
> system.time(
+   f1 <- pre(diabetes ~ ., data = PimaIndiansDiabetes, maxdepth = 3, 
+             use_suggestion = FALSE))
   user  system elapsed 
  44.84    1.17   46.47 
> f1

Final ensemble with cv error within 1se of minimum: 
  lambda =  0.01991541
  number of terms = 20
  mean cv error (se) = 0.9591347 (0.03588186) 

         rule  coefficient                                          description
       rule82   0.62514441                          glucose > 127 & mass > 29.9
     rule1180  -0.41465507           glucose <= 154 & pregnant <= 7 & age <= 28
      rule580  -0.35054073  glucose <= 127 & pedigree <= 0.637 & glucose <= 111
      rule919  -0.30923561             glucose <= 154 & mass <= 26.3 & mass > 0
     rule1336  -0.25223230                        glucose <= 154 & mass <= 27.3
     rule1502  -0.23085816            glucose <= 154 & age <= 31 & mass <= 45.3
      rule649  -0.22895438           glucose <= 127 & pregnant <= 6 & age <= 34
     rule1627  -0.19438094        glucose <= 151 & mass <= 45.3 & pregnant <= 6
     rule1695  -0.19401622        glucose <= 123 & mass <= 45.3 & pregnant <= 6
      rule167  -0.19050765    glucose <= 132 & pedigree <= 0.821 & mass <= 38.8
      rule518   0.14996597                            glucose > 154 & age <= 57
     rule1344  -0.14887229    glucose <= 111 & pedigree <= 0.629 & mass <= 38.1
        rule9  -0.14325183        glucose <= 154 & mass <= 30.9 & pregnant <= 5
  (Intercept)   0.14008263                                                 <NA>
      rule992  -0.13422342     glucose <= 134 & glucose <= 107 & insulin <= 148
      rule416  -0.13358126                        glucose <= 143 & mass <= 28.8
      rule636  -0.09022366       glucose <= 154 & glucose <= 99 & insulin <= 86
        rule1  -0.02752370                        glucose <= 145 & mass <= 27.3
      rule335   0.02325657                          glucose > 123 & mass > 29.9
      rule832  -0.01011519        glucose <= 144 & pregnant <= 6 & mass <= 30.8
     rule1738  -0.00630662       glucose <= 154 & glucose <= 103 & mass <= 37.3
> set.seed(seed)
> system.time(
+   f2 <- pre(diabetes ~ ., data = PimaIndiansDiabetes, maxdepth = 3, 
+             use_suggestion = TRUE))
   user  system elapsed 
  16.39    0.06   16.86 
> f2

Final ensemble with cv error within 1se of minimum: 
  lambda =  0.01991541
  number of terms = 20
  mean cv error (se) = 0.9591347 (0.03588186) 

         rule  coefficient                                          description
       rule82   0.62514441                          glucose > 127 & mass > 29.9
     rule1180  -0.41465507           glucose <= 154 & pregnant <= 7 & age <= 28
      rule580  -0.35054073  glucose <= 127 & pedigree <= 0.637 & glucose <= 111
      rule919  -0.30923561             glucose <= 154 & mass <= 26.3 & mass > 0
     rule1336  -0.25223230                        glucose <= 154 & mass <= 27.3
     rule1502  -0.23085816            glucose <= 154 & age <= 31 & mass <= 45.3
      rule649  -0.22895438           glucose <= 127 & pregnant <= 6 & age <= 34
     rule1627  -0.19438094        glucose <= 151 & mass <= 45.3 & pregnant <= 6
     rule1695  -0.19401622        glucose <= 123 & mass <= 45.3 & pregnant <= 6
      rule167  -0.19050765    glucose <= 132 & pedigree <= 0.821 & mass <= 38.8
      rule518   0.14996597                            glucose > 154 & age <= 57
     rule1344  -0.14887229    glucose <= 111 & pedigree <= 0.629 & mass <= 38.1
        rule9  -0.14325183        glucose <= 154 & mass <= 30.9 & pregnant <= 5
  (Intercept)   0.14008263                                                 <NA>
      rule992  -0.13422342     glucose <= 134 & glucose <= 107 & insulin <= 148
      rule416  -0.13358126                        glucose <= 143 & mass <= 28.8
      rule636  -0.09022366       glucose <= 154 & glucose <= 99 & insulin <= 86
        rule1  -0.02752370                        glucose <= 145 & mass <= 27.3
      rule335   0.02325657                          glucose > 123 & mass > 29.9
      rule832  -0.01011519        glucose <= 144 & pregnant <= 6 & mass <= 30.8
     rule1738  -0.00630662       glucose <= 154 & glucose <= 103 & mass <= 37.3

I think the reason for the difference is that all() evaluates all elements, whereas the duplicated function is quite optimized. A downside of the change I have made is that changes in the partykit package may lead to more errors, and these will be harder to detect.
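
As a rough sketch of the idea, and not the code in this pull request, the following compares a pairwise all()-based check with a single call to duplicated() on a small made-up 0/1 rule matrix. The canonicalisation trick (flipping every column whose first entry is 0) is an assumption chosen for the illustration; it treats duplicates and complements together, which only makes sense when both are to be removed.

```r
# Hypothetical 0/1 rule matrix: one column per rule, one row per observation.
rulevars <- cbind(rule1 = c(1, 0, 1, 1),
                  rule2 = c(1, 0, 1, 1),   # duplicate of rule1
                  rule3 = c(0, 1, 0, 0),   # complement of rule1
                  rule4 = c(1, 1, 0, 0))

# Pairwise approach: compare each column with every earlier column.
pairwise_redundant <- function(x) {
  drop <- logical(ncol(x))
  for (j in seq_len(ncol(x))) {
    for (i in seq_len(j - 1)) {
      if (all(x[, i] == x[, j]) || all(x[, i] == 1 - x[, j])) {
        drop[j] <- TRUE
        break
      }
    }
  }
  drop
}

# duplicated() approach: flip columns that start with 0 so that a rule and its
# complement share one canonical form, then look for duplicated columns.
flip <- rulevars[1, ] == 0
canon <- rulevars
canon[, flip] <- 1 - rulevars[, flip, drop = FALSE]
redundant <- as.vector(duplicated(canon, MARGIN = 2))

rulevars_reduced <- rulevars[, !redundant, drop = FALSE]
```

Both approaches flag rule2 and rule3 here; the second replaces the nested loop with a single vectorised call.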

The CRAN check does not pass at this point. The reason is that some imports need to be added. I ran devtools::document() after adding a few importFrom roxygen tags, but I got:

Warning: The existing 'NAMESPACE' file was not generated by roxygen2, and will not be overwritten.

Thus, I did not make any changes to the NAMESPACE file. The trickiest part of the changes was predict.party. A few notes for my own sanity:

#####
# Notes on predict.party

# It finds the terminal node for newdata by
#   fitted_node(node_party(object), data = newdata, 
#     vmatch = vmatch, perm = perm)
# Then it calls predict_party.constparty with id = fitted. This function uses
#   response <- party$fitted[["(response)"]]
#   weights <- party$fitted[["(weights)"]]
#   fitted <- party$fitted[["(fitted)"]]
# as the input to .predict_party_constparty to find the final value. It computes
# a weighted mean of the responses over the observations (for numeric outcomes). 
# However, this is redundant and we can do it beforehand when we only want 
# expected outcomes for the terminal nodes of the trees
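
The precomputation described in these notes can be sketched in base R as follows. The three vectors are made-up stand-ins for the "(fitted)", "(response)" and "(weights)" columns of party$fitted, and new_nodes stands in for what fitted_node() would return; none of this calls partykit itself, so it is an illustration under those assumptions rather than the actual implementation.

```r
# Hypothetical stand-ins for the columns of party$fitted.
fitted   <- c(2, 2, 3, 3, 3)     # terminal node id of each training observation
response <- c(10, 20, 1, 2, 3)   # numeric outcome
weights  <- c(1, 3, 1, 1, 2)     # case weights

# Weighted mean of the response within each terminal node, computed once.
node_means <- vapply(
  split(seq_along(fitted), fitted),
  function(idx) weighted.mean(response[idx], weights[idx]),
  numeric(1))

# Predicting for new data then reduces to a lookup by terminal node id.
new_nodes <- c(3, 2, 2)          # assumed output of the node-assignment step
preds <- unname(node_means[as.character(new_nodes)])
```

With the node means stored up front, each prediction costs one lookup instead of recomputing a weighted mean per call.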