liuyanguu closed this issue 5 years ago.
Hi liuyanguu, thanks for your kind words!
You can define the tuning space, and thus the booster type, by changing the `par.set` used for optimization in `autoxgboost`.
The default is `autoxgbparset`:

```r
autoxgbparset = makeParamSet(
  makeNumericParam("eta", lower = 0.01, upper = 0.2),
  makeNumericParam("gamma", lower = -7, upper = 6, trafo = function(x) 2^x),
  makeIntegerParam("max_depth", lower = 3, upper = 20),
  makeNumericParam("colsample_bytree", lower = 0.5, upper = 1),
  makeNumericParam("colsample_bylevel", lower = 0.5, upper = 1),
  makeNumericParam("lambda", lower = -10, upper = 10, trafo = function(x) 2^x),
  makeNumericParam("alpha", lower = -10, upper = 10, trafo = function(x) 2^x),
  makeNumericParam("subsample", lower = 0.5, upper = 1)
)
```
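The `trafo = function(x) 2^x` arguments mean that `gamma`, `lambda` and `alpha` are searched on a log2 scale: the tuner proposes the exponent, and the transformed value is what xgboost receives. A quick base-R check of the resulting ranges (the helper name `log2_trafo` is hypothetical, just a stand-in for the anonymous `trafo` function above):

```r
# Hypothetical helper mirroring trafo = function(x) 2^x in the parameter set
log2_trafo <- function(x) 2^x

# Search bounds on the log2 scale, copied from autoxgbparset
bounds <- list(
  gamma  = c(lower = -7,  upper = 6),
  lambda = c(lower = -10, upper = 10),
  alpha  = c(lower = -10, upper = 10)
)

# Transformed bounds, i.e. the range of values actually passed to xgboost
effective <- lapply(bounds, log2_trafo)
print(effective)

# The midpoint of the log2 interval maps to the geometric mean of the bounds
mid_gamma <- log2_trafo(mean(bounds$gamma))
print(mid_gamma)
```

So `gamma` ranges from 2^-7 (about 0.0078) to 64, and `lambda`/`alpha` from about 0.001 to 1024, with each doubling of the value covering the same share of the search interval.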
But there is also a predefined `autoxgbparset.mixed`, which tunes over `gbtree`, `gblinear` and `dart`:

```r
autoxgbparset.mixed = makeParamSet(
  makeDiscreteParam("booster", values = c("gbtree", "gblinear", "dart")),
  makeDiscreteParam("sample_type", values = c("uniform", "weighted"), requires = quote(booster == "dart")),
  makeDiscreteParam("normalize_type", values = c("tree", "forest"), requires = quote(booster == "dart")),
  makeNumericParam("rate_drop", lower = 0, upper = 1, requires = quote(booster == "dart")),
  makeNumericParam("skip_drop", lower = 0, upper = 1, requires = quote(booster == "dart")),
  makeLogicalParam("one_drop", requires = quote(booster == "dart")),
  makeDiscreteParam("grow_policy", values = c("depthwise", "lossguide")),
  makeIntegerParam("max_leaves", lower = 0, upper = 8, trafo = function(x) 2^x, requires = quote(grow_policy == "lossguide")),
  makeIntegerParam("max_bin", lower = 2L, upper = 9, trafo = function(x) 2^x),
  makeNumericParam("eta", lower = 0.01, upper = 0.2),
  makeNumericParam("gamma", lower = -7, upper = 6, trafo = function(x) 2^x),
  makeIntegerParam("max_depth", lower = 3, upper = 20),
  makeNumericParam("colsample_bytree", lower = 0.5, upper = 1),
  makeNumericParam("colsample_bylevel", lower = 0.5, upper = 1),
  makeNumericParam("lambda", lower = -10, upper = 10, trafo = function(x) 2^x),
  makeNumericParam("alpha", lower = -10, upper = 10, trafo = function(x) 2^x),
  makeNumericParam("subsample", lower = 0.5, upper = 1)
)
```
But you can also easily create a custom parameter set yourself that tunes solely over `dart` (or just remove the other two options from the `booster` parameter of the mixed set).
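A minimal sketch of the custom-set route for a regression task, assuming the `mlr`, `mlrMBO` and `autoxgboost` packages are installed. The dataset (`mtcars` with target `mpg`), the name `par_dart_small`, the trimmed-down parameter list and the single MBO iteration are illustrative choices only, not recommendations:

```r
library(mlr)          # makeRegrTask(); also attaches ParamHelpers (makeParamSet etc.)
library(mlrMBO)       # makeMBOControl(), setMBOControlTermination()
library(autoxgboost)  # autoxgboost()

# Illustrative regression task (assumption: any numeric target works the same way)
task <- makeRegrTask(data = mtcars, target = "mpg")

# Hypothetical dart-only search space: "booster" has a single value, so every
# proposed configuration uses dart and the dart parameters need no requires
par_dart_small <- makeParamSet(
  makeDiscreteParam("booster", values = "dart"),
  makeNumericParam("eta", lower = 0.01, upper = 0.2),
  makeIntegerParam("max_depth", lower = 3, upper = 20),
  makeNumericParam("rate_drop", lower = 0, upper = 1),
  makeNumericParam("skip_drop", lower = 0, upper = 1)
)

# Keep the run short: one MBO iteration after the initial design
ctrl <- makeMBOControl()
ctrl <- setMBOControlTermination(ctrl, iters = 1L)

# Pass the custom space through the par.set argument
res <- autoxgboost(task, control = ctrl, par.set = par_dart_small)
print(res)
```

Widening the search is a matter of adding more entries, for example `sample_type`, `normalize_type`, `one_drop` or the tree parameters from `autoxgbparset`, and raising `iters` for a real run.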
Hope this helps.
Thank you very much, Janek! This is super helpful! Indeed, we can choose to run only dart by setting `par.set` to:

```r
par_dart <- makeParamSet(
  makeNumericParam("eta", lower = 0.01, upper = 0.2),
  makeNumericParam("gamma", lower = -7, upper = 6, trafo = function(x) 2^x),
  makeIntegerParam("max_depth", lower = 3, upper = 20),
  makeNumericParam("colsample_bytree", lower = 0.5, upper = 1),
  makeNumericParam("colsample_bylevel", lower = 0.5, upper = 1),
  makeNumericParam("lambda", lower = -10, upper = 10, trafo = function(x) 2^x),
  makeNumericParam("alpha", lower = -10, upper = 10, trafo = function(x) 2^x),
  makeNumericParam("subsample", lower = 0.5, upper = 1),
  makeDiscreteParam("booster", values = c("dart")),
  makeDiscreteParam("sample_type", values = c("uniform", "weighted"), requires = quote(booster == "dart")),
  makeDiscreteParam("normalize_type", values = c("tree", "forest"), requires = quote(booster == "dart")),
  makeNumericParam("rate_drop", lower = 0, upper = 1, requires = quote(booster == "dart")),
  makeNumericParam("skip_drop", lower = 0, upper = 1, requires = quote(booster == "dart")),
  makeLogicalParam("one_drop", requires = quote(booster == "dart"))
)
```
Or we can tune over all three booster types by assigning `par.set = autoxgbparset.mixed`. The `autoxgbparset.mixed` object needs to be loaded in advance, though. Very convenient.
Interestingly, `gbtree` was chosen in the end, although the result from `dart` was very close in terms of tuning MSE.
Thanks again for your kind help.
Yang
Hi Janek, thanks a lot for this wonderful package. I really enjoy testing it in my projects. I am wondering how I can change the booster type from `gbtree` to `dart` in a regression task. I would really appreciate your input.