The focus here is on speed-ups, but there are also some bug fixes.
All the results below are from running:
> summary(microbenchmark::microbenchmark(
+ pre = airq.ens <- pre(Ozone ~ ., data=airquality[complete.cases(airquality),]),
+ times = 100))
The initial result was:
  expr      min       lq    mean   median       uq      max neval
1  pre 4.018802 4.171205 4.33873 4.295319 4.407155 6.901721   100
After commit 6f298c8:
  expr     min       lq    mean   median       uq      max neval
1  pre 3.57509 3.670837 3.86512 3.793877 3.998536 4.700677   100
After commit d237fb1:
  expr      min       lq     mean  median       uq      max neval
1  pre 2.312484 2.355866 2.455929 2.40926 2.498872 3.090584   100
After commit 7b008b0:
  expr      min      lq     mean  median       uq      max neval
1  pre 1.804114 1.90265 2.039485 1.97219 2.083083 3.855207   100
The present result is:
  expr      min       lq     mean   median       uq      max neval
1  pre 1.376896 1.448779 1.524643 1.487653 1.582981 2.037187   100
Further reduction in computation time can be achieved by looking at the various calls to functions in the partykit package. There is a lot of validation in these functions that makes them slow. However, the biggest computational cost is in training the trees. The partykit package is pure R code, so I suspect that a compiled version of ctree, like the one in the party package, could reduce computation time further. The classification part with learning rates has not been optimized, though a step similar to the one applied to partykit::ctree could be taken there as well.
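One way to verify where the remaining time goes is to profile a single fit with base R's Rprof (just an illustration; the partykit tree-growing functions should dominate the output):

Rprof(prof_file <- tempfile())
airq.ens <- pre(Ozone ~ ., data = airquality[complete.cases(airquality), ])
Rprof(NULL)
head(summaryRprof(prof_file)$by.total, 10)  # functions with the largest total time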
I do not get the negation here, so I removed the ! from !(....) and updated the tests. This reduced the CV error in the example you sent from:
mean cv error (se) = 339.6366 (92.21531)
... to:
mean cv error (se) = 299.798 (74.28342)
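As a toy illustration of the kind of bug this is (a made-up condition, not the actual code): with a logical rule vector, !(...) selects the complement of the rule rather than the rule itself.

dat <- airquality[complete.cases(airquality), ]
r <- dat$Temp > 77  # membership of a made-up rule
mean(r)             # fraction of observations the rule covers
mean(!(r))          # a stray ! silently selects the complement instead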
All the tests are run against the updated version. The current version of the complement removal is quite slow when there are a lot of rules. I updated it to a new version, which is still slow, so I have implemented an alternative. You can get it by calling with use_suggestion = TRUE. If you do not like it, then remove it (and in any case remove the comment!):
# TODO: remove this if you do not want the suggestion I propose
if (use_suggestion && removeduplicates && removecomplements) {
...
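The idea behind the suggestion, in a minimal sketch (not the exact code in the commit; rulevars is assumed to be a logical matrix with one column per rule):

remove_complements_sketch <- function(rulevars) {
  # Normalize each column so that a rule and its complement become identical:
  # flip every column whose first entry is TRUE.
  normalized <- rulevars
  flip <- normalized[1, ]
  normalized[, flip] <- !normalized[, flip]
  # duplicated() on the transposed matrix marks duplicated rows of t(), i.e.
  # duplicated columns, which now cover both duplicates and complements.
  rulevars[, !duplicated(t(normalized)), drop = FALSE]
}

A single duplicated() pass removes duplicates and complements together, which is why the branch above requires both removeduplicates and removecomplements.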
To illustrate that the new approach is faster and gives the same result, consider these calls:
> data(PimaIndiansDiabetes, package = "mlbench")
> nrow(PimaIndiansDiabetes) # a bit larger
[1] 768
> set.seed(seed <- 9602165)
> system.time(
+ f1 <- pre(diabetes ~ ., data = PimaIndiansDiabetes, maxdepth = 3,
+ use_suggestion = FALSE))
   user  system elapsed
  44.84    1.17   46.47
> f1
Final ensemble with cv error within 1se of minimum:
lambda = 0.01991541
number of terms = 20
mean cv error (se) = 0.9591347 (0.03588186)
rule coefficient description
rule82 0.62514441 glucose > 127 & mass > 29.9
rule1180 -0.41465507 glucose <= 154 & pregnant <= 7 & age <= 28
rule580 -0.35054073 glucose <= 127 & pedigree <= 0.637 & glucose <= 111
rule919 -0.30923561 glucose <= 154 & mass <= 26.3 & mass > 0
rule1336 -0.25223230 glucose <= 154 & mass <= 27.3
rule1502 -0.23085816 glucose <= 154 & age <= 31 & mass <= 45.3
rule649 -0.22895438 glucose <= 127 & pregnant <= 6 & age <= 34
rule1627 -0.19438094 glucose <= 151 & mass <= 45.3 & pregnant <= 6
rule1695 -0.19401622 glucose <= 123 & mass <= 45.3 & pregnant <= 6
rule167 -0.19050765 glucose <= 132 & pedigree <= 0.821 & mass <= 38.8
rule518 0.14996597 glucose > 154 & age <= 57
rule1344 -0.14887229 glucose <= 111 & pedigree <= 0.629 & mass <= 38.1
rule9 -0.14325183 glucose <= 154 & mass <= 30.9 & pregnant <= 5
(Intercept) 0.14008263 <NA>
rule992 -0.13422342 glucose <= 134 & glucose <= 107 & insulin <= 148
rule416 -0.13358126 glucose <= 143 & mass <= 28.8
rule636 -0.09022366 glucose <= 154 & glucose <= 99 & insulin <= 86
rule1 -0.02752370 glucose <= 145 & mass <= 27.3
rule335 0.02325657 glucose > 123 & mass > 29.9
rule832 -0.01011519 glucose <= 144 & pregnant <= 6 & mass <= 30.8
rule1738 -0.00630662 glucose <= 154 & glucose <= 103 & mass <= 37.3
> set.seed(seed)
> system.time(
+ f2 <- pre(diabetes ~ ., data = PimaIndiansDiabetes, maxdepth = 3,
+ use_suggestion = TRUE))
   user  system elapsed
  16.39    0.06   16.86
> f2
Final ensemble with cv error within 1se of minimum:
lambda = 0.01991541
number of terms = 20
mean cv error (se) = 0.9591347 (0.03588186)
rule coefficient description
rule82 0.62514441 glucose > 127 & mass > 29.9
rule1180 -0.41465507 glucose <= 154 & pregnant <= 7 & age <= 28
rule580 -0.35054073 glucose <= 127 & pedigree <= 0.637 & glucose <= 111
rule919 -0.30923561 glucose <= 154 & mass <= 26.3 & mass > 0
rule1336 -0.25223230 glucose <= 154 & mass <= 27.3
rule1502 -0.23085816 glucose <= 154 & age <= 31 & mass <= 45.3
rule649 -0.22895438 glucose <= 127 & pregnant <= 6 & age <= 34
rule1627 -0.19438094 glucose <= 151 & mass <= 45.3 & pregnant <= 6
rule1695 -0.19401622 glucose <= 123 & mass <= 45.3 & pregnant <= 6
rule167 -0.19050765 glucose <= 132 & pedigree <= 0.821 & mass <= 38.8
rule518 0.14996597 glucose > 154 & age <= 57
rule1344 -0.14887229 glucose <= 111 & pedigree <= 0.629 & mass <= 38.1
rule9 -0.14325183 glucose <= 154 & mass <= 30.9 & pregnant <= 5
(Intercept) 0.14008263 <NA>
rule992 -0.13422342 glucose <= 134 & glucose <= 107 & insulin <= 148
rule416 -0.13358126 glucose <= 143 & mass <= 28.8
rule636 -0.09022366 glucose <= 154 & glucose <= 99 & insulin <= 86
rule1 -0.02752370 glucose <= 145 & mass <= 27.3
rule335 0.02325657 glucose > 123 & mass > 29.9
rule832 -0.01011519 glucose <= 144 & pregnant <= 6 & mass <= 30.8
rule1738 -0.00630662 glucose <= 154 & glucose <= 103 & mass <= 37.3
I think the difference is that all() has to evaluate all elements in every pairwise comparison, while the duplicated() function is quite optimized. An issue with the change I have made is that you might get errors that are harder to detect if there are changes in the partykit package.
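A rough way to see the scaling difference on made-up data (p rules, n observations): the pairwise all()-style check does O(p^2) full-column comparisons, while duplicated() makes a single pass.

set.seed(1)
p <- 500; n <- 1000
m <- matrix(sample(c(TRUE, FALSE), n * p, replace = TRUE), n, p)
system.time(
  for (i in seq_len(p - 1)) for (j in (i + 1):p) all(m[, i] == !m[, j]))
system.time(duplicated(t(m)))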
The CRAN check does not pass at this point because some imports need to be added. I ran devtools::document() after adding a few importFrom roxygen tags, but I got:
Warning: The existing 'NAMESPACE' file was not generated by roxygen2, and will not be overwritten.
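For reference, the tags I mean look roughly like this (the exact functions to import are whatever the code actually calls; this set is only illustrative):

#' @importFrom partykit ctree ctree_control node_party fitted_node
#' @importFrom stats complete.cases weighted.mean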
Given that warning, I did not make any changes to the NAMESPACE file myself. The trickiest part of the changes was predict.party. A few notes for my own sanity:
#####
# Notes on predict.party
# It finds the terminal node for newdata with
#   fitted_node(node_party(object), data = newdata,
#               vmatch = vmatch, perm = perm)
# Then it calls predict_party.constparty with id = fitted. This function uses
#   response <- party$fitted[["(response)"]]
#   weights  <- party$fitted[["(weights)"]]
#   fitted   <- party$fitted[["(fitted)"]]
# as the input to .predict_party_constparty to find the final value. It computes
# a weighted mean of the responses over the observations (for numeric outcomes).
# However, this is redundant, and we can do this up front when we only want the
# expected outcomes for the terminal nodes of the trees.
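To make the note concrete, here is a minimal sketch of the precomputation idea (the helper names are mine, not the code in the commit; it assumes a constparty tree with a numeric response):

library(partykit)

node_means <- function(tree) {
  # Weighted mean of the response within each terminal node, computed once.
  response <- tree$fitted[["(response)"]]
  node_ids <- tree$fitted[["(fitted)"]]
  weights  <- tree$fitted[["(weights)"]]
  if (is.null(weights)) weights <- rep(1, length(response))
  vapply(split(seq_along(response), node_ids),
         function(i) weighted.mean(response[i], weights[i]), numeric(1))
}

predict_terminal <- function(tree, newdata, means = node_means(tree)) {
  # Prediction is then just a terminal-node lookup.
  ids <- fitted_node(node_party(tree), data = newdata)
  unname(means[as.character(ids)])
}

# Should agree with partykit's own predictions:
dat <- airquality[complete.cases(airquality), ]
tree <- ctree(Ozone ~ ., data = dat)
all.equal(predict_terminal(tree, dat), unname(predict(tree)))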