marjoleinF / pre

an R package for deriving Prediction Rule Ensembles
58 stars 17 forks source link

More general sparse prediction ensemble with bug fixes #6

Closed boennecd closed 7 years ago

boennecd commented 7 years ago

Main addition

Added the gpe function to fit more general sparse prediction ensembles. The advantages are:

1) It is easy to generalize to other methods of getting basis functions.
2) The implementation is modular, so it is easy to test each part.
3) An option is added to use basis functions formed from products of hinge functions through the MARS method. It uses the earth package.

An example of the new function is given below. We use the Boston housing example also used in Friedman & Popescu (2008):

library(pre)
data(BostonHousing, package = "mlbench")

#####
# Current pre function
set.seed(seed <- 7134103)
fit <- pre(
  medv ~ ., data= BostonHousing)

fit
#>
#>Final ensemble with cv error within 1se of minimum: 
#>  lambda =  0.06058199
#>  number of terms = 89
#>  mean cv error (se) = 9.266844 (1.05969) 
#>
#>         rule    coefficient                                      description
#>  (Intercept)  26.2415519380                                             <NA>
#>     rule1448   5.1595444757                     lstat <= 14.1 & dis <= 1.413
#>     rule1136   2.6062947579   lstat <= 15.02 & dis <= 2.4259 & lstat <= 9.53
#>      rule977   2.5536183701                     ptratio <= 17.8 & rm > 7.412
#>     rule1473   2.0347582098                     crim <= 18.4982 & b > 100.63
#>     rule1538  -2.0274777750                       lstat > 3.32 & nox > 0.668
#>      rule802   1.8979863706     rm <= 7.416 & lstat <= 9.93 & crim > 1.42502
#>      rule820   1.6996688985         ptratio <= 19.6 & rm > 6.538 & rm > 7.42
#>      rule397   1.6969034641                    lstat <= 5.21 & dis <= 3.1992
#>     rule1121   1.5249533052                     ptratio <= 19.2 & rm > 7.135
#>      rule442   1.4558173083                    lstat <= 9.69 & lstat <= 4.63
#>     rule1590  -1.3811742722                           rad <= 8 & rm <= 7.765
#>     rule1115   1.3571204247                   lstat <= 9.64 & crim > 4.54192
#>      rule858   1.3426190577              rm > 6.635 & tax <= 403 & rm > 7.47
#>     rule1241  -1.3184774411      lstat > 5.39 & chas %in% c("0") & tax > 222
#>     rule1014  -1.0574678789                    lstat > 9.62 & crim > 5.69175
#>     rule1419   1.0363067008        rm <= 6.824 & nox <= 0.51 & dis <= 4.4377
#>     rule1431   1.0039372488                       ptratio <= 19 & tax <= 289
#>     rule1280   0.9726751526           lstat <= 13.34 & rad <= 8 & rm > 7.041
#>     rule1267  -0.9350508111                     ptratio > 14.4 & lstat > 5.1
#>     rule1092  -0.9297643319                    lstat > 3.95 & ptratio > 19.2
#>     rule1030  -0.9273113538     lstat > 11.66 & crim > 6.96215 & nox > 0.668
#>     rule1533   0.8692635936                     ptratio <= 19.6 & rm > 6.635
#>      rule997   0.7900331984                       lstat <= 5.21 & rm > 6.854
#>      rule870  -0.7741811543        rm <= 6.824 & lstat > 14.37 & nox > 0.524
#>     rule1405   0.7555178473      lstat <= 9.53 & crim <= 1.51902 & rm > 6.54
#>     rule1574  -0.7349460457                         tax > 222 & ptratio > 16
#>     rule1268   0.7235344119                 crim <= 18.811 & ptratio <= 20.9
#>     rule1428   0.7129514604                     ptratio <= 18.7 & rm > 6.897
#>     rule1192   0.6996357883                      ptratio <= 18.7 & rm > 6.77
#>      rule258  -0.6866779323         lstat > 5.12 & lstat > 14.79 & tax > 469
#>      rule687  -0.6521046019         rm <= 7.007 & lstat > 16.03 & age > 84.7
#>     rule1393  -0.6507141624         ptratio <= 19.2 & rm <= 7.47 & tax > 222
#>      rule946   0.6500854252  lstat <= 9.93 & ptratio > 15.2 & crim > 0.06417
#>     rule1444  -0.6127256522                                     lstat > 9.93
#>      rule704   0.5652217441                       rm > 6.939 & ptratio <= 17
#>      rule221  -0.5578130221                     lstat > 9.54 & lstat > 19.78
#>     rule1557  -0.5368797856         crim <= 5.29305 & rad <= 8 & rm <= 6.606
#>      rule716  -0.5319602397    lstat > 9.54 & crim <= 18.811 & lstat > 19.78
#>     rule1339   0.5234864513                                       rm > 6.545
#>      rule794  -0.5169004549       lstat > 5.1 & crim > 9.51363 & nox > 0.668
#>      rule884   0.5055781816                   lstat <= 9.88 & crim > 0.51183
#>     rule1521   0.5032169345                                       rm > 6.041
#>      rule748  -0.4912615920         rm <= 6.824 & lstat > 14.37 & age > 73.1
#>     rule1249  -0.4899802374                      lstat > 3.95 & ptratio > 13
#>     rule1447   0.4545461647                     ptratio <= 19.6 & rm > 7.007
#>       rule10   0.3789574568     lstat <= 9.54 & rm <= 7.007 & crim > 0.05425
#>     rule1074   0.3644603057                                    lstat <= 7.73
#>     rule1352  -0.3634765829      ptratio > 17.8 & crim <= 7.40389 & rad <= 4
#>      rule617  -0.3221147704       lstat > 5.12 & lstat > 14.36 & nox > 0.597
#>     rule1138  -0.3157577996      lstat <= 15.02 & dis > 2.4259 & rm <= 6.072
#>     rule1517  -0.3077847710                chas %in% c("0") & ptratio > 16.6
#>      rule807  -0.3048722137        rm <= 6.833 & lstat > 12.01 & b <= 378.08
#>     rule1438   0.3002165892                                       rm > 6.142
#>      rule578  -0.2573953842   rm <= 6.939 & lstat <= 14.69 & crim <= 4.55587
#>       rule41  -0.2517522537     lstat > 9.69 & lstat > 14.98 & dis <= 2.4358
#>     rule1539  -0.2348244841                           rm <= 6.163 & rad <= 8
#>      rule868   0.2294730498           rm <= 6.824 & lstat <= 14.37 & rad > 7
#>      rule843  -0.2174452970                    lstat > 4.59 & ptratio > 19.6
#>     rule1565  -0.2100790698                       ptratio > 19.6 & rm > 4.97
#>     rule1043  -0.2075915064        ptratio <= 18.7 & rm <= 7.416 & tax > 222
#>     rule1097  -0.2041425378        lstat > 5.12 & lstat > 11.25 & age > 73.1
#>     rule1229  -0.1699250081         rm <= 6.849 & lstat > 11.69 & age > 70.2
#>     rule1169  -0.1698078186                    lstat > 9.59 & crim > 7.40389
#>      rule723  -0.1685469041                    lstat > 9.93 & crim > 5.82115
#>     rule1420  -0.1612472364         rm <= 6.824 & nox <= 0.51 & dis > 4.4377
#>     rule1432  -0.1518339426             ptratio <= 19 & tax > 289 & rad <= 4
#>       rule22  -0.1353304583         lstat > 4.86 & lstat > 14.36 & tax > 469
#>     rule1279  -0.1333709544          lstat <= 13.34 & rad <= 8 & rm <= 7.041
#>     rule1349   0.1128400183                     ptratio <= 17.8 & tax <= 224
#>      rule523   0.1087033028       lstat > 8.93 & lstat <= 21.46 & b > 100.63
#>     rule1253  -0.1063203331                                   ptratio > 17.8
#>      rule606  -0.0973060832     lstat > 5.12 & lstat > 9.93 & crim > 5.29305
#>        lstat  -0.0882873315                        3.1225 <= lstat <= 29.945
#>      rule952   0.0873806838                                       rm > 6.849
#>     rule1252   0.0806475986                      ptratio <= 17.8 & rm > 7.61
#>     rule1370  -0.0730056355                         rm <= 7.82 & rm <= 5.841
#>     rule1588   0.0725094750    lstat <= 10.13 & crim <= 1.22358 & rm > 6.122
#>     rule1327   0.0589262949                     ptratio <= 19.6 & tax <= 296
#>     rule1218  -0.0470914150  rm <= 7.691 & chas %in% c("0") & ptratio > 19.7
#>      rule730  -0.0393811911                       rm <= 6.606 & lstat > 17.6
#>      rule981  -0.0377467519       lstat <= 11.65 & rm <= 7.42 & dis > 1.9865
#>      rule847   0.0337284169         ptratio <= 19.2 & rm > 6.718 & rm > 7.61
#>     rule1347   0.0202459420                     ptratio <= 19.2 & tax <= 296
#>      rule535  -0.0021213743        rm <= 6.833 & lstat > 16.03 & nox > 0.597
#>     rule1312   0.0009228270         lstat <= 7.56 & rad > 2 & crim <= 0.5405
#>          age  -0.0005724327                             9.9625 <= age <= 100
#>     rule1386  -0.0003278398                chas %in% c("0") & ptratio > 19.2
#>      rule230   0.0002922414                   lstat <= 9.69 & crim > 0.51183
#>     rule1345  -0.0001945838                ptratio > 15.6 & chas %in% c("0")

#####
# We pass the following to get a similar model
set.seed(seed)
fit <- gpe(
  medv ~ ., data= BostonHousing, 
  # Additionally we specify the learners though this is the default
  base_learners = list(gpe_linear(), gpe_tress()))

fit
#>
#>Final ensemble with cv error within 1se of minimum: 
#>  lambda =  0.06058199
#>  number of terms = 89
#>  mean cv error (se) = 9.301042 (1.055479) 
#>
#>                                             description    coefficient
#>                                             (Intercept)  26.2642262751
#>               lTerm(age, lb = 10, ub = 100, scale = 28)  -0.0291336322
#>            lTerm(lstat, lb = 3.1, ub = 30, scale = 6.9)  -1.5755272295
#>                rTerm(chas %in% c("0") & ptratio > 16.6)  -0.3017956309
#>                rTerm(chas %in% c("0") & ptratio > 19.2)  -0.0002459567
#>                     rTerm(crim <= 18.4982 & b > 100.63)   2.0302190072
#>                 rTerm(crim <= 18.811 & ptratio <= 20.9)   0.7289432657
#>         rTerm(crim <= 5.29305 & rad <= 8 & rm <= 6.606)  -0.5263773387
#>    rTerm(lstat <= 10.13 & crim <= 1.22358 & rm > 6.122)   0.0702232792
#>       rTerm(lstat <= 11.65 & rm <= 7.42 & dis > 1.9865)  -0.0478096075
#>          rTerm(lstat <= 13.34 & rad <= 8 & rm <= 7.041)  -0.1455067815
#>           rTerm(lstat <= 13.34 & rad <= 8 & rm > 7.041)   1.0135225627
#>                     rTerm(lstat <= 14.1 & dis <= 1.413)   5.1767687879
#>   rTerm(lstat <= 15.02 & dis <= 2.4259 & lstat <= 9.53)   2.6068239404
#>      rTerm(lstat <= 15.02 & dis > 2.4259 & rm <= 6.072)  -0.3149370605
#>                    rTerm(lstat <= 5.12 & dis <= 3.1992)   1.6972313953
#>                       rTerm(lstat <= 5.21 & rm > 6.854)   0.7849146664
#>         rTerm(lstat <= 7.56 & rad > 2 & crim <= 0.5405)   0.0001685414
#>                                    rTerm(lstat <= 7.73)   0.3571833469
#>      rTerm(lstat <= 9.53 & crim <= 1.51902 & rm > 6.54)   0.7395970198
#>     rTerm(lstat <= 9.54 & rm <= 7.007 & crim > 0.05425)   0.3747728677
#>                   rTerm(lstat <= 9.59 & crim > 0.51183)   0.0023040884
#>                   rTerm(lstat <= 9.64 & crim > 4.54192)   1.3327289859
#>                    rTerm(lstat <= 9.69 & lstat <= 4.63)   1.4515752812
#>                   rTerm(lstat <= 9.88 & crim > 0.51183)   0.4945665925
#>  rTerm(lstat <= 9.93 & ptratio > 15.2 & crim > 0.06417)   0.6515158354
#>     rTerm(lstat > 11.66 & crim > 6.96215 & nox > 0.668)  -0.8971317610
#>                       rTerm(lstat > 3.32 & nox > 0.668)  -2.0361949703
#>                      rTerm(lstat > 3.95 & ptratio > 13)  -0.4943443907
#>                    rTerm(lstat > 3.95 & ptratio > 19.2)  -0.9238231817
#>      rTerm(lstat > 4.45 & lstat > 7.56 & lstat > 19.78)  -0.5511724467
#>                    rTerm(lstat > 4.59 & ptratio > 19.6)  -0.1769374057
#>         rTerm(lstat > 4.86 & lstat > 14.36 & tax > 469)  -0.1139645574
#>       rTerm(lstat > 5.1 & crim > 9.51363 & nox > 0.668)  -0.5207536258
#>        rTerm(lstat > 5.12 & lstat > 11.25 & age > 73.1)  -0.2071782044
#>       rTerm(lstat > 5.12 & lstat > 14.36 & nox > 0.597)  -0.3277211975
#>         rTerm(lstat > 5.12 & lstat > 14.79 & tax > 469)  -0.6952136021
#>     rTerm(lstat > 5.12 & lstat > 9.93 & crim > 5.29305)  -0.1083298467
#>      rTerm(lstat > 5.39 & chas %in% c("0") & tax > 222)  -1.3222477264
#>       rTerm(lstat > 8.93 & lstat <= 21.46 & b > 100.63)   0.1052940003
#>    rTerm(lstat > 9.54 & crim <= 18.811 & lstat > 19.78)  -0.5265462791
#>                    rTerm(lstat > 9.59 & crim > 7.40389)  -0.1849320761
#>                    rTerm(lstat > 9.62 & crim > 5.69175)  -1.0649872443
#>     rTerm(lstat > 9.69 & lstat > 14.98 & dis <= 2.4358)  -0.2495621775
#>                                     rTerm(lstat > 9.88)  -0.0012949493
#>                                     rTerm(lstat > 9.93)  -0.6117792981
#>                    rTerm(lstat > 9.93 & crim > 5.82115)  -0.1557704040
#>                     rTerm(ptratio <= 17.8 & rm > 7.412)   2.5379468826
#>                      rTerm(ptratio <= 17.8 & rm > 7.61)   0.0991394149
#>                     rTerm(ptratio <= 17.8 & tax <= 224)   0.1224968061
#>        rTerm(ptratio <= 18.7 & rm <= 7.416 & tax > 222)  -0.2044595441
#>                      rTerm(ptratio <= 18.7 & rm > 6.77)   0.7069801567
#>                     rTerm(ptratio <= 18.7 & rm > 6.897)   0.7121850961
#>                       rTerm(ptratio <= 19 & tax <= 289)   0.9981686758
#>             rTerm(ptratio <= 19 & tax > 289 & rad <= 4)  -0.1546460103
#>         rTerm(ptratio <= 19.2 & rm <= 7.47 & tax > 222)  -0.6277358596
#>         rTerm(ptratio <= 19.2 & rm > 6.718 & rm > 7.61)   0.0144862310
#>                     rTerm(ptratio <= 19.2 & rm > 7.135)   1.4980512595
#>                      rTerm(ptratio <= 19.2 & rm > 7.47)   1.3426696146
#>                     rTerm(ptratio <= 19.2 & tax <= 296)   0.0100276029
#>         rTerm(ptratio <= 19.6 & rm > 6.538 & rm > 7.42)   1.7292299549
#>                     rTerm(ptratio <= 19.6 & rm > 6.635)   0.8718481420
#>                     rTerm(ptratio <= 19.6 & rm > 7.007)   0.4194070377
#>                     rTerm(ptratio <= 19.6 & tax <= 296)   0.0710338569
#>                     rTerm(ptratio > 14.4 & lstat > 5.1)  -0.9388892390
#>                rTerm(ptratio > 15.6 & chas %in% c("0"))  -0.0034435347
#>                                   rTerm(ptratio > 17.8)  -0.1030656850
#>      rTerm(ptratio > 17.8 & crim <= 7.40389 & rad <= 4)  -0.3610229352
#>                       rTerm(ptratio > 19.6 & rm > 4.97)  -0.2409379074
#>                           rTerm(rad <= 8 & rm <= 7.765)  -1.3742012055
#>                           rTerm(rm <= 6.163 & rad <= 8)  -0.2375083662
#>                       rTerm(rm <= 6.606 & lstat > 17.6)  -0.0263932575
#>           rTerm(rm <= 6.824 & lstat <= 14.37 & rad > 7)   0.2380238678
#>         rTerm(rm <= 6.824 & lstat > 14.37 & age > 73.1)  -0.4953490822
#>        rTerm(rm <= 6.824 & lstat > 14.37 & nox > 0.524)  -0.7679429732
#>        rTerm(rm <= 6.824 & nox <= 0.51 & dis <= 4.4377)   1.0391455157
#>         rTerm(rm <= 6.824 & nox <= 0.51 & dis > 4.4377)  -0.1569652794
#>        rTerm(rm <= 6.833 & lstat > 12.01 & b <= 378.08)  -0.3060049918
#>         rTerm(rm <= 6.849 & lstat > 11.69 & age > 70.2)  -0.1745522292
#>   rTerm(rm <= 6.939 & lstat <= 14.69 & crim <= 4.55587)  -0.2616050983
#>         rTerm(rm <= 7.007 & lstat > 16.03 & age > 84.7)  -0.6501284428
#>     rTerm(rm <= 7.416 & lstat <= 9.93 & crim > 1.42502)   1.8957507244
#>  rTerm(rm <= 7.691 & chas %in% c("0") & ptratio > 19.7)  -0.0389296990
#>                         rTerm(rm <= 7.82 & rm <= 5.841)  -0.0733464076
#>                                       rTerm(rm > 6.041)   0.5071732832
#>                                       rTerm(rm > 6.142)   0.2924334755
#>                                       rTerm(rm > 6.545)   0.5406203299
#>                                       rTerm(rm > 6.849)   0.0880494966
#>                       rTerm(rm > 6.939 & ptratio <= 17)   0.5601240358
#>                         rTerm(tax > 222 & ptratio > 16)  -0.7425775733
#####
# We can also add a MARS model (with the earth package)
set.seed(seed)
fit <- gpe(
  medv ~ ., data= BostonHousing,
  # we add gpe_earth to get hinge product functions from earth
  base_learners = list(gpe_tress(), gpe_linear(), 
                       gpe_earth()))

fit
#>
#>Final ensemble with cv error within 1se of minimum: 
#>  lambda =  0.05029623
#>  number of terms = 98
#>  mean cv error (se) = 8.638106 (0.9926755) 
#>
#>                                                                               description    coefficient
#>                                                                               (Intercept)  21.8846033967
#>                                                       rTerm(b > 354.31 & ptratio <= 20.9)   0.2337546029
#>                                                       rTerm(crim <= 18.4982 & b > 100.63)   1.2391770271
#>                                                   rTerm(crim <= 18.811 & ptratio <= 20.9)   0.8078818876
#>                                           rTerm(crim <= 5.29305 & rad <= 8 & rm <= 6.606)  -0.0259516509
#>                                            rTerm(lstat <= 13.34 & rad <= 8 & rm <= 7.041)  -0.0811895267
#>                                       rTerm(lstat <= 14.81 & crim <= 4.83567 & rm > 6.77)   0.2075844182
#>                                      rTerm(lstat <= 15.02 & dis > 1.6132 & dis <= 4.6947)   0.5890612013
#>                                        rTerm(lstat <= 15.02 & dis > 2.4259 & rm <= 6.072)  -0.3862386895
#>                                                                      rTerm(lstat <= 4.69)   0.5508015319
#>                                           rTerm(lstat <= 7.56 & rad > 2 & crim <= 0.5405)   0.4214638385
#>                                        rTerm(lstat <= 9.53 & crim <= 1.51902 & rm > 6.54)   0.2520888776
#>                                       rTerm(lstat <= 9.54 & rm <= 7.007 & crim > 0.05425)   0.4165462453
#>                                                      rTerm(lstat <= 9.69 & lstat <= 4.63)   0.5172905065
#>                                                         rTerm(lstat <= 9.69 & rm > 6.635)   0.2112135652
#>                                                     rTerm(lstat <= 9.88 & crim > 0.51183)   0.5057516393
#>                                       rTerm(lstat <= 9.93 & crim <= 1.83377 & rm > 6.069)   0.0731324279
#>                                         rTerm(lstat <= 9.93 & crim <= 2.924 & rm > 6.606)   0.3839695240
#>                                    rTerm(lstat <= 9.93 & ptratio > 15.2 & crim > 0.06417)   0.4341448736
#>                                                                      rTerm(lstat > 10.13)  -0.3562904431
#>                                       rTerm(lstat > 11.25 & crim > 7.40389 & nox > 0.668)  -0.3786367449
#>                                       rTerm(lstat > 11.66 & crim > 6.96215 & nox > 0.668)  -0.9259503679
#>                                                         rTerm(lstat > 3.32 & nox > 0.668)  -1.0928010200
#>                                        rTerm(lstat > 4.45 & lstat > 7.56 & lstat > 19.78)  -0.5309836729
#>                                            rTerm(lstat > 5.04 & tax <= 432 & rm <= 7.079)  -0.4687954049
#>                                         rTerm(lstat > 5.1 & crim > 9.51363 & nox > 0.668)  -0.2088083630
#>                                         rTerm(lstat > 5.12 & lstat > 14.36 & nox > 0.597)  -0.0200574259
#>                                        rTerm(lstat > 5.39 & chas %in% c("0") & tax > 222)  -0.6300964408
#>                                      rTerm(lstat > 5.39 & lstat > 11.69 & crim > 0.54452)  -0.0268514534
#>                                       rTerm(lstat > 5.39 & lstat > 14.8 & crim > 0.55778)  -0.3275990990
#>                                        rTerm(lstat > 7.6 & lstat > 14.37 & crim > 0.7842)  -0.1465612691
#>                                      rTerm(lstat > 7.73 & crim <= 15.288 & lstat > 19.78)  -0.0516351688
#>                                         rTerm(lstat > 8.93 & lstat <= 21.46 & b > 100.63)   0.2039669761
#>                                        rTerm(lstat > 9.53 & crim > 5.82401 & nox > 0.668)  -0.4319108579
#>                                      rTerm(lstat > 9.54 & crim <= 18.811 & lstat > 19.78)  -0.6959327046
#>                                                      rTerm(lstat > 9.62 & crim > 5.69175)  -0.4228034259
#>                                       rTerm(lstat > 9.69 & lstat > 14.98 & dis <= 2.4358)  -0.6353703838
#>                                                       rTerm(ptratio <= 17.8 & rm > 7.412)   1.6652318165
#>                                                        rTerm(ptratio <= 18.6 & rm > 7.47)   0.5683803575
#>                                          rTerm(ptratio <= 18.7 & rm <= 7.416 & tax > 222)  -0.5228257040
#>                                                        rTerm(ptratio <= 18.7 & rm > 6.77)   0.2597554800
#>                                                       rTerm(ptratio <= 18.7 & rm > 6.897)   0.8467098554
#>                                                         rTerm(ptratio <= 19 & tax <= 289)   0.6716810189
#>                                               rTerm(ptratio <= 19 & tax > 289 & rad <= 4)  -0.1849259002
#>                                                        rTerm(ptratio <= 19.2 & rm > 6.54)   0.4051521652
#>                                                       rTerm(ptratio <= 19.2 & rm > 7.135)   0.7253939528
#>                                                       rTerm(ptratio <= 19.2 & tax <= 296)   0.1907046791
#>                                                    rTerm(ptratio <= 19.6 & dis <= 3.2797)   0.0006554801
#>                                                       rTerm(ptratio <= 19.6 & tax <= 296)   0.0013088439
#>                                                       rTerm(ptratio <= 19.7 & tax <= 277)   0.1448952251
#>                                                       rTerm(ptratio > 14.4 & lstat > 5.1)  -0.3527262618
#>                                                                     rTerm(ptratio > 17.8)  -0.7061203383
#>                                                         rTerm(ptratio > 19.6 & rm > 4.97)  -0.3475416951
#>                                                           rTerm(rm <= 6.593 & age > 73.4)  -0.7122770855
#>                                            rTerm(rm <= 6.824 & lstat <= 14.37 & rad <= 7)  -0.2582394616
#>                                             rTerm(rm <= 6.824 & lstat <= 14.37 & rad > 7)   0.2179381941
#>                                          rTerm(rm <= 6.824 & lstat > 14.37 & nox > 0.524)  -0.8134375570
#>                                          rTerm(rm <= 6.824 & nox <= 0.51 & dis <= 4.4377)   0.6691561913
#>                                                           rTerm(rm <= 6.824 & nox > 0.51)  -0.3111650726
#>                                          rTerm(rm <= 6.833 & lstat > 12.01 & b <= 378.08)  -1.0093716262
#>                                          rTerm(rm <= 6.833 & lstat > 16.03 & nox > 0.597)  -0.4604477597
#>                                           rTerm(rm <= 6.849 & lstat > 11.69 & age > 70.2)  -0.1736455491
#>                                           rTerm(rm <= 7.007 & lstat > 16.03 & age > 84.7)  -0.3314445061
#>                                                                         rTerm(rm > 6.041)   0.4724929891
#>                                                                         rTerm(rm > 6.142)   0.2734159402
#>                                                           rTerm(tax > 222 & ptratio > 16)  -0.9334752088
#>                          eTerm(pmax(nox - 0.614, 0) * pmax(rm - 5.985, 0), scale = 0.022)  -0.2222769184
#>                        eTerm(pmax(rm - 6.372, 0) * pmax(ptratio - 18.9, 0), scale = 0.21)  -0.0289299751
#>                         eTerm(pmax(14.4208 - crim, 0) * pmax(rm - 5.709, 0), scale = 8.5)   1.1075621151
#>                            eTerm(pmax(1.7821 - dis, 0) * pmax(395.18 - b, 0), scale = 11)   0.2794133542
#>                        eTerm(pmax(rm - 6.431, 0) * pmax(ptratio - 18.3, 0), scale = 0.29)  -0.0195561886
#>                        eTerm(pmax(3.2628 - dis, 0) * pmax(5.98 - lstat, 0), scale = 0.57)   0.6404726236
#>                               eTerm(pmax(rm - 5.747, 0) * pmax(287 - tax, 0), scale = 27)   0.3632560883
#>                         eTerm(pmax(rm - 6.404, 0) * pmax(18.6 - ptratio, 0), scale = 1.4)   0.0469702200
#>        eTerm((chas == "1") * pmax(rm - 6.404, 0) * pmax(ptratio - 18.6, 0), scale = 0.18)  -1.2043061683
#>                            eTerm(pmax(2.2834 - dis, 0) * pmax(b - 369.3, 0), scale = 5.4)  -1.1245372217
#>                           eTerm(pmax(tax - 300, 0) * pmax(lstat - 18.68, 0), scale = 950)  -0.3256758619
#>                           eTerm(pmax(rm - 6.167, 0) * pmax(6.43 - lstat, 0), scale = 1.2)   0.5941025912
#>                         eTerm(pmax(dis - 2.9634, 0) * pmax(lstat - 6.12, 0), scale = 8.8)  -0.7042369751
#>  eTerm(pmax(2.3158 - dis, 0) * pmax(329.46 - b, 0) * pmax(18.06 - lstat, 0), scale = 110)  -0.9808408298
#>                                                    eTerm(pmax(127.36 - b, 0), scale = 23)  -0.0548319567
#>                              eTerm(pmax(281 - tax, 0) * pmax(7.2 - lstat, 0), scale = 41)   0.0620776149
#>                        eTerm(pmax(dis - 2.741, 0) * pmax(18.6 - ptratio, 0), scale = 4.2)  -0.2306348037
#>      eTerm(pmax(crim - 6.96215, 0) * (chas == "1") * pmax(25.79 - lstat, 0), scale = 1.2)   0.6316434437
#>                         eTerm(pmax(rm - 6.286, 0) * pmax(19.2 - ptratio, 0), scale = 1.8)   1.9747927217
#>                              eTerm(pmax(rm - 6.286, 0) * pmax(80.8 - age, 0), scale = 15)   1.8149618438
#>                           eTerm(pmax(304 - tax, 0) * pmax(ptratio - 18.4, 0), scale = 22)  -0.1643944708
#>                        eTerm(pmax(3.3317 - dis, 0) * pmax(12.12 - lstat, 0), scale = 2.3)   1.6651938154
#>                          eTerm(pmax(rm - 6.897, 0) * pmax(23.29 - lstat, 0), scale = 5.3)   0.0010902261
#>   eTerm(pmax(1.8172 - dis, 0) * pmax(tax - 304, 0) * pmax(19.69 - lstat, 0), scale = 310)   4.0626159076
#>   eTerm(pmax(2.7147 - dis, 0) * pmax(b - 378.35, 0) * pmax(23.69 - lstat, 0), scale = 37)  -0.3840885808
#>                                                eTerm(pmax(crim - 18.811, 0), scale = 5.2)  -0.1347250123
#>   eTerm(pmax(18.811 - crim, 0) * pmax(1.7883 - dis, 0) * pmax(tax - 307, 0), scale = 290)   0.9790009935
#>                             eTerm(pmax(tax - 666, 0) * pmax(lstat - 8.01, 0), scale = 62)  -0.7577164387
#>                           eTerm(pmax(254 - tax, 0) * pmax(20.31 - lstat, 0), scale = 160)   0.3761311861
#>                       eTerm(pmax(crim - 7.02259, 0) * pmax(2.2625 - dis, 0), scale = 5.1)  -0.3032953115
#>                           eTerm(pmax(rm - 6.833, 0) * pmax(22.6 - lstat, 0), scale = 5.4)   2.2328395325
#>                         eTerm(pmax(2.7778 - dis, 0) * pmax(22.6 - lstat, 0), scale = 4.2)   0.0269969968
#>                          eTerm(pmax(rm - 6.333, 0) * pmax(21.78 - lstat, 0), scale = 7.7)   0.7320707333

The new function is documented so further information about the setup is in ?gpe after installation e.g. by calling:

devtools::install_github("boennecd/pre@b5d54f64a0243083b07e5b458b54f72f541df881")

The se on the error seems to be smaller in the (few) data sets I have looked at when the MARS terms are used. See commit fe984578a2195804a8bd7c920c0e17ce52267b0a for another example.

I can add an example with gpe to the readme if you want

Other major changes

CRAN and tests

I fixed the issues with tests that did not pass with devtools::check(document = FALSE, cleanup = FALSE). The cause and solution are described in commit 4f4bc47d86a9da7fa3694b0fff4056b5273e7601.

I strongly suggest keeping the tests in the builds. This will cause them to also run on CRAN, which can uncover bugs on different platforms. The new version does include the tests in the build.

README

I have changed the formatting of the readme using knitr. You can see the new readme with plots on https://github.com/boennecd/pre/tree/b5d54f64a0243083b07e5b458b54f72f541df881

Subsampling

Notice that gpe_sample samples with replacement when sample_frac < 1 and the weights are unequal. The motivation is that the weights will not be representative otherwise. I have not changed the setup in pre, but I suggest doing the same there as in gpe_sample.

Factors

Fixed bugs with factor levels when the learning rate is positive in pre. See commit aebf7d6af1f7582d743408f771f84d9b05f9d262.

Removing complements

Changed the method to remove complement rules to a considerably faster method that provides the same result. See commit 23eb1cd7829d3e76af86e20ead83c57e005b46ee.

Other remarks

The partykit::glmtree is terribly slow for larger data sets. I have not found a solution yet but have posted some options in commit fdc23a2b38afb50f1e87d6a3b07ea1122856a023

Minor changes

bsnullmodcall will change the verbose argument of $call to the argument passed to bsnullmodcall. Fixed typos using devtools::spell_check. Updated a test that did not pass when I downloaded the new version of the package; see commit fdba929c08b56dbb38015b10cb98e7c6fa384022.

boennecd commented 7 years ago

I have made a few additions. Sorry for doing this so shortly after I opened the request. The changes are:

Here is an example of the changes. The Sonar data set is binary. Notice both the computation time and mean cv error. The computation using gradient boosting may provide the answer to the comments I pose in fdc23a2b38afb50f1e87d6a3b07ea1122856a023

#####
# Install version
devtools::install_github(
  "boennecd/pre@34845cb30fb3f20847f0c8eca076d878eebd6ad9")

# Load and prepare data
data(Sonar, package = "mlbench")
dat <- Sonar

#####
# Fit w/ model as in Friedman & Popescu (2008) using partykit::glmtree
set.seed(seed <- 4455668)
system.time(fit <- gpe(
  Class ~ ., data = dat,
  base_learners = list(gpe_tress(ntrees = 100), gpe_linear())))
#>    user  system elapsed 
#>      12       0      12 
options(digits = 4) # used to alter print
fit
#> 
#> Final ensemble with cv error within 1se of minimum: 
#>   lambda =  0.02065
#>   number of terms = 20
#>   mean cv error (se) = 0.6228 (0.06798) 
#> 
#> [output abbreviated]

options(digits = 2) # use fewer digits
fit
#> 
#> Final ensemble with cv error within 1se of minimum: 
#>   lambda =  0.021
#>   number of terms = 20
#>   mean cv error (se) = 0.62 (0.068) 
#> 
#> [output abbreviated]

#####
# Fit model w/ MARS terms w/ trees using partykit::glmtree 
set.seed(seed)
system.time(fit <- gpe(
  Class ~ ., data = dat,
  base_learners = list(gpe_tress(ntrees = 100), gpe_linear(), gpe_earth())))
#> Beware that gpe_earth will use gradiant boosting
#>    user  system elapsed 
#>   13.98    0.01   14.00 
fit
#> 
#> Final ensemble with cv error within 1se of minimum: 
#>   lambda =  0.017
#>   number of terms = 27
#>   mean cv error (se) = 0.56 (0.043) 
#> 
#> [output abbreviated]

#####
# Fit model w/ MARS terms and trees using partykit::ctree w/ gradient boosting
set.seed(seed)
system.time(fit <- gpe(
  Class ~ ., data = dat,
  base_learners = list(gpe_tress(ntrees = 100, use_grad = TRUE), 
                       gpe_linear(), gpe_earth())))
#> Beware that gpe_earth will use gradiant boosting
#>    user  system elapsed 
#>    5.00    0.02    5.02 
fit
#> 
#> Final ensemble with cv error within 1se of minimum: 
#>   lambda =  0.011
#>   number of terms = 43
#>   mean cv error (se) = 0.59 (0.042) 
#> 
#> [output abbreviated]

#####
# Fit model w/ MARS terms and trees using partykit::ctree w/ gradient boosting
# w/ higher learning rate
set.seed(seed)
system.time(fit <- gpe(
  Class ~ ., data = dat,
  base_learners = list(gpe_tress(ntrees = 100, use_grad = TRUE, learnrate = .5), 
                       gpe_linear(), 
                       gpe_earth(learnrate = .5))))
#> Beware that gpe_earth will use gradiant boosting
#>    user  system elapsed 
#>     4.8     0.0     4.8 
fit
#> 
#> Final ensemble with cv error within 1se of minimum: 
#>   lambda =  0.01
#>   number of terms = 45
#>   mean cv error (se) = 0.52 (0.06) 
#> 
#> [output abbreviated]

#####
# Fit model w/ MARS terms and trees using partykit::ctree w/ gradient boosting
# w/ higher learning rate and elastic net instead of L1, yielding a less sparse
# solution
set.seed(seed)
system.time(fit <- gpe(
  Class ~ ., data = dat,
  base_learners = list(gpe_tress(ntrees = 100, use_grad = TRUE, learnrate = .5), 
                       gpe_linear(), 
                       gpe_earth(learnrate = .5)), 
  penalized_trainer = gpe_cv.glmnet(alpha = 0.5)))
#> Beware that gpe_earth will use gradiant boosting
#>    user  system elapsed 
#>    5.02    0.03    5.40 
fit
#> 
#> Final ensemble with cv error within 1se of minimum: 
#>   lambda =  0.0095
#>   number of terms = 164
#>   mean cv error (se) = 0.42 (0.055) 
#> 
#> [output abbreviated]
boennecd commented 7 years ago

I added a feature to screen out terms from the MARS model using pairwise correlations. The motivation is that the previous option scales poorly when more basis functions are added. It is similar to the remove_duplicates_complements argument for gpe_tress.

Passing a value of cor_thresh < 1 to gpe_earth means that, for each pair of basis functions whose pairwise correlation is greater than cor_thresh, the later basis function is removed. A value greater than 1 or NULL yields no removal. There are some examples below. The performance is close and the computation time decreases. The difference becomes greater if more models are added from the MARS model.

# Ex 1
data(BostonHousing, package = "mlbench")

invisible(gc())
set.seed(seed <- 6560036)
system.time(fit <- gpe(
  medv ~ ., data= BostonHousing,
  base_learners = list(gpe_tress(), gpe_linear(), 
                       gpe_earth(cor_thresh = NULL)
  )))
#>   user  system elapsed 
#>  12.42    0.18   12.59 
fit
#>
#>Final ensemble with cv error within 1se of minimum: 
#>  lambda =  0.07355997
#>  number of terms = 86
#>  mean cv error (se) = 10.83148 (2.424318) 
#>
#> [output abbreviated]

invisible(gc())
set.seed(seed)
system.time(fit <- gpe(
  medv ~ ., data= BostonHousing,
  base_learners = list(gpe_tress(), gpe_linear(), 
                       gpe_earth())))
#>   user  system elapsed 
#>  10.89    0.05   10.94 
fit
#>
#>Final ensemble with cv error within 1se of minimum: 
#>  lambda =  0.07355997
#>  number of terms = 81
#>  mean cv error (se) = 10.99771 (2.440741) 
#>
#> [output abbreviated]

# Ex 2
data(Sonar, package = "mlbench")
dat <- Sonar

invisible(gc())
set.seed(seed <- 4455668)
system.time(fit <- gpe(
  Class ~ ., data = dat,
  base_learners = list(gpe_tress(ntrees = 100, use_grad = TRUE), 
                       gpe_linear(), 
                       gpe_earth(cor_thresh = NULL)), 
  penalized_trainer = gpe_cv.glmnet(alpha = 0.5)))
#>Beware that gpe_earth will use gradient boosting
#>   user  system elapsed 
#>   4.24    0.01    4.25 
fit
#>
#>Final ensemble with cv error within 1se of minimum: 
#>  lambda =  0.01395031
#>  number of terms = 142
#>  mean cv error (se) = 0.4975857 (0.04412764) 
#>
#> [output abbreviated]

invisible(gc())
set.seed(seed)
system.time(fit <- gpe(
  Class ~ ., data = dat,
  base_learners = list(gpe_tress(ntrees = 100, use_grad = TRUE), 
                       gpe_linear(), 
                       gpe_earth()), 
  penalized_trainer = gpe_cv.glmnet(alpha = 0.5)))
#>Beware that gpe_earth will use gradient boosting
#>   user  system elapsed 
#>   3.96    0.02    3.97 
fit
#>
#>Final ensemble with cv error within 1se of minimum: 
#>  lambda =  0.01331624
#>  number of terms = 120
#>  mean cv error (se) = 0.4973539 (0.04179383) 
#>
#> [output abbreviated]

# Microbenchmark on Ex 2
set.seed(seed)
summary(microbenchmark::microbenchmark(
  no_remove = {
    set.seed(seed)
    gpe(
      Class ~ ., data = dat,
      base_learners = list(gpe_tress(ntrees = 100, use_grad = TRUE), 
                           gpe_linear(), 
                           gpe_earth(cor_thresh = NULL)), 
      penalized_trainer = gpe_cv.glmnet(alpha = 0.5))
  },

  with_remove = {
    set.seed(seed)
    gpe(
      Class ~ ., data = dat,
      base_learners = list(gpe_tress(ntrees = 100, use_grad = TRUE), 
                           gpe_linear(), 
                           gpe_earth()), 
      penalized_trainer = gpe_cv.glmnet(alpha = 0.5))
  },

  times = 5))
#>Beware that gpe_earth will use gradient boosting
#>Beware that gpe_earth will use gradient boosting
#>Beware that gpe_earth will use gradient boosting
#>Beware that gpe_earth will use gradient boosting
#>Beware that gpe_earth will use gradient boosting
#>Beware that gpe_earth will use gradient boosting
#>Beware that gpe_earth will use gradient boosting
#>Beware that gpe_earth will use gradient boosting
#>Beware that gpe_earth will use gradient boosting
#>Beware that gpe_earth will use gradient boosting
#>         expr      min       lq     mean   median       uq      max neval cld
#>1   no_remove 4.378992 4.601111 4.657146 4.658193 4.771663 4.875772     5   b
#>2 with_remove 3.936837 4.020951 4.078190 4.062532 4.128391 4.242240     5  a