oracle / tribuo

Tribuo - A Java machine learning library
https://tribuo.org
Apache License 2.0
1.26k stars, 173 forks

XGBoostClassificationTrainer tuning parameters #238

Closed (emilyBerg2022 closed this issue 2 years ago)

emilyBerg2022 commented 2 years ago

I am running into some issues while using the XGBoost algorithm to solve a classification problem, in particular when tuning the XGBoost parameters. The approach I want to use is to change one parameter at a time and then run the training process while monitoring the validation metrics. That is, I use a random search to find the best value for each parameter in turn: eta, then gamma, then max_depth, then min_child_weight, then subsample, colsample_bytree, and so on. Finally, I select the combination of parameters that gives the best validation performance.
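The procedure described above can be sketched roughly as follows. This is a minimal illustration, not part of the original report: the class name SequentialRandomSearch, the tune method, the range map, and the validationScore callback are all hypothetical names. In real use the callback would train a model with the candidate parameter map and return a validation metric where higher is better; here it is left abstract. Integer-valued parameters such as max_depth would additionally need rounding, which this sketch does not do.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.function.ToDoubleFunction;

/** Hypothetical helper: tunes one parameter at a time with random sampling. */
final class SequentialRandomSearch {

    /**
     * @param ranges parameter name mapped to {low, high}; iteration order is the tuning order
     * @param start initial parameter values (not modified)
     * @param samplesPerParameter random draws tried for each parameter
     * @param seed RNG seed, for reproducibility
     * @param validationScore assumed callback: trains with the map, returns a score (higher is better)
     */
    static Map<String, Object> tune(Map<String, double[]> ranges,
                                    Map<String, Object> start,
                                    int samplesPerParameter,
                                    long seed,
                                    ToDoubleFunction<Map<String, Object>> validationScore) {
        Random rng = new Random(seed);
        Map<String, Object> best = new HashMap<>(start);
        double bestScore = validationScore.applyAsDouble(best);

        for (Map.Entry<String, double[]> entry : ranges.entrySet()) {
            double low = entry.getValue()[0];
            double high = entry.getValue()[1];
            for (int i = 0; i < samplesPerParameter; i++) {
                // Vary only the current parameter; keep the best values found so far for the rest.
                Map<String, Object> candidate = new HashMap<>(best);
                candidate.put(entry.getKey(), low + rng.nextDouble() * (high - low));
                double score = validationScore.applyAsDouble(candidate);
                if (score > bestScore) {
                    bestScore = score;
                    best = candidate;
                }
            }
        }
        return best;
    }
}
```

Passing the ranges as a LinkedHashMap fixes the order in which the parameters are tuned (eta first, then gamma, and so on). The returned map is the kind of value one would then hand to the trainer's parameter-map constructor.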

The problem is that when I implement this with XGBoostClassificationTrainer (https://tribuo.org/learn/4.2/javadoc/org/tribuo/classification/xgboost/XGBoostClassificationTrainer.html#%3Cinit%3E(int,java.util.Map)), I can't pass the tuned parameters to the constructor XGBoostClassificationTrainer(int numTrees, Map<String,Object> parameters), in particular through the parameters map. The documentation says "This gives direct access to the XGBoost parameter map." What happens is that the parameters I set manually (while tuning) are overwritten with the default values in class XGBoostTrainer, probably because the postConfig() method is called in every constructor of this class, including the one the custom parameters are forwarded to. The postConfig() method (found here: https://github.com/oracle/tribuo/blob/98ed939d96259e58bea1b02adce18caa22fce7fe/Classification/XGBoost/src/main/java/org/tribuo/classification/xgboost/XGBoostClassificationTrainer.java) looks like:

public void postConfig() {
    parameters.put("eta", eta);
    parameters.put("gamma", gamma);
    parameters.put("max_depth", maxDepth);
    parameters.put("min_child_weight", minChildWeight);
    parameters.put("subsample", subsample);
    parameters.put("colsample_bytree", featureSubsample);
    parameters.put("lambda", lambda);
    parameters.put("alpha", alpha);
    parameters.put("nthread", nThread);
    parameters.put("seed", seed);
    if (silent == 1) {
        parameters.put("verbosity", 0);
    } else {
        parameters.put("verbosity", verbosity.value);
    }
    parameters.put("booster", booster.paramName);
    parameters.put("tree_method", treeMethod.paramName);
}

Would you please shed some light on this issue and point out what I am doing wrong, so that I can manually tune the XGBoost parameters with the class in question? Your help on this matter is appreciated. Thanks in advance.

Craigacp commented 2 years ago

Yep, that's a bug on our end. At some point in the past that constructor didn't call postConfig, which is why that doc was there, and we'll fix it so that postConfig is no longer called from it. When the user supplies parameters, we should copy those into the provenance information rather than copying from the fields into the parameters.
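To make the mechanism concrete, here is a minimal sketch of the overwrite and of the fix described above. SketchTrainer is a hypothetical stand-in, not Tribuo's actual class: the two fields and their values are illustrative assumptions, and provenance recording is not modelled.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical stand-in for a trainer with configurable fields and a raw parameter map. */
final class SketchTrainer {
    // Illustrative field values, playing the role of the fields that postConfig() copies.
    private double eta = 0.3;
    private int maxDepth = 6;
    private final Map<String, Object> parameters = new HashMap<>();

    private SketchTrainer() {}

    /** Buggy variant: postConfig() runs after the user map is copied and clobbers it. */
    static SketchTrainer buggy(Map<String, Object> userParameters) {
        SketchTrainer trainer = new SketchTrainer();
        trainer.parameters.putAll(userParameters);
        trainer.postConfig(); // overwrites "eta" and "max_depth" with the field values
        return trainer;
    }

    /** Fixed variant: the user-supplied map is used as given; postConfig() is not called. */
    static SketchTrainer fixed(Map<String, Object> userParameters) {
        SketchTrainer trainer = new SketchTrainer();
        trainer.parameters.putAll(userParameters);
        return trainer;
    }

    /** Copies the field values into the parameter map, replacing any existing entries. */
    private void postConfig() {
        parameters.put("eta", eta);
        parameters.put("max_depth", maxDepth);
    }

    /** Returns a defensive copy of the parameters that would be passed to XGBoost. */
    Map<String, Object> getParameters() {
        return new HashMap<>(parameters);
    }
}
```

With a user map containing eta = 0.05, the buggy variant ends up with the field value for eta, while the fixed variant keeps 0.05.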

Craigacp commented 2 years ago

This bug is fixed in main and now also correctly records the overriding parameters in the provenance object.