dmlc / XGBoost.jl

XGBoost Julia Package

How to obtain training loss from a `Booster`? #128

Open scheidan opened 1 year ago

scheidan commented 1 year ago

During training the loss values are reported:

[ Info: XGBoost: starting training.
[ Info: [1]     train-rmse:1.01477459966132155
[ Info: [2]     train-rmse:0.90233272804303499
[ Info: [3]     train-rmse:0.81263379794430390
[ Info: [4]     train-rmse:0.72035539615691646
[ Info: [5]     train-rmse:0.65068651330880800
[ Info: [6]     train-rmse:0.60636982744130230
[ Info: [7]     train-rmse:0.56702425812618651
[ Info: [8]     train-rmse:0.52695104528071834
[ Info: [9]     train-rmse:0.49079542128198789
[ Info: [10]    train-rmse:0.46314743311883549
[ Info: Training rounds complete.

Is there a way to obtain the final training loss from a `Booster` object? Or, even better, the losses for all rounds?

Apologies if this is trivial; I just could not find anything in the documentation or the code.

Thanks!

ExpandingMan commented 1 year ago

Unfortunately I don't think the libxgboost API provides any access to this. It probably doesn't even cache them. Of course if I turn out to be wrong about that it would be a nice thing to add.
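
That said, if only the final training loss is needed, it can be recomputed from the trained model's predictions. A minimal sketch (assuming the tuple-style data argument accepted by recent package versions and a squared-error objective, so the result should match the last logged `train-rmse`):

using XGBoost, Statistics

X, y = randn(100, 4), randn(100)  # placeholder data
bst = xgboost((X, y); num_round=10, objective="reg:squarederror")

# Recompute the final training RMSE from the model's own predictions;
# for this objective it should agree with the last `train-rmse` line.
ŷ = predict(bst, X)
rmse = sqrt(mean((ŷ .- y) .^ 2))

This only recovers the final value, of course, not the per-round history.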

scheidan commented 1 year ago

I just saw that it is possible with the R version, e.g.:

library(xgboost)

data(agaricus.train, package='xgboost')
train <- agaricus.train
mod <- xgboost(data = train$data, label = train$label,
               max.depth = 2, eta = 1, nthread = 2,
               nrounds = 10, objective = "binary:logistic")
mod$evaluation_log  # gets you a data.frame with the errors

Digging a bit into the R package, I found that they collect the log with `cb.evaluation.log()`; see https://github.com/dmlc/xgboost/blob/9dd8d70f0e00a214585f145a85c0c21ce6f132f1/R-package/R/callbacks.R#L106. However, I do not understand how it works.

bobaronoff commented 1 year ago

Also digging through the R version, specifically the callback and evaluation-log routines. It appears to me that they may be parsing the text going to stdout and accumulating the results in the evaluation log. Both the R and Python implementations make this data programmatically available. I have no idea how they do it, but the text is getting to the REPL somehow.
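
On the Julia side, the `[ Info:` prefix in the output above is the giveaway: these lines are emitted through Julia's standard `Logging` system rather than written to raw stdout, so they can be intercepted programmatically. A rough sketch with a custom logger (the `CollectLogger` type here is only an illustration, not part of XGBoost.jl):

using Logging

# Stores every logged message string for later inspection.
struct CollectLogger <: AbstractLogger
    messages::Vector{String}
end
Logging.min_enabled_level(::CollectLogger) = Logging.Info
Logging.shouldlog(::CollectLogger, args...) = true
Logging.catch_exceptions(::CollectLogger) = false
Logging.handle_message(l::CollectLogger, level, message, args...; kwargs...) =
    push!(l.messages, string(message))

logger = CollectLogger(String[])
# bst = with_logger(logger) do
#     xgboost(...)
# end
# `logger.messages` then holds one string per logged line.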

StevenWhitaker commented 1 year ago

I recently needed to grab the logged losses. Here is what I ended up doing:

using XGBoost, Logging

io = IOBuffer()
logger = SimpleLogger(io)  # route all log messages into the buffer
bst = with_logger(logger) do
    xgboost(...)
end
flush(io)
log = String(take!(io))  # the captured log as a single string

At this point `log` is a `String` with the following format:

"""
/ Info: XGBoost: starting training.
\ @ XGBoost ...file.jl:line
/ Info: [1]     train-rmse:1.01477459966132155
\ @ XGBoost ...file.jl:line
/ Info: [2]     train-rmse:0.90233272804303499
\ @ XGBoost ...file.jl:line
/ Info: [3]     train-rmse:0.81263379794430390
\ @ XGBoost ...file.jl:line
/ Info: [4]     train-rmse:0.72035539615691646
\ @ XGBoost ...file.jl:line
/ Info: [5]     train-rmse:0.65068651330880800
\ @ XGBoost ...file.jl:line
/ Info: [6]     train-rmse:0.60636982744130230
\ @ XGBoost ...file.jl:line
/ Info: [7]     train-rmse:0.56702425812618651
\ @ XGBoost ...file.jl:line
/ Info: [8]     train-rmse:0.52695104528071834
\ @ XGBoost ...file.jl:line
/ Info: [9]     train-rmse:0.49079542128198789
\ @ XGBoost ...file.jl:line
/ Info: [10]    train-rmse:0.46314743311883549
\ @ XGBoost ...file.jl:line
/ Info: Training rounds complete.
\ @ XGBoost ...file.jl:line
"""

Then I parsed `log` as follows:

log_split = split(log, '\n')
log_loss = log_split[3:2:end-3]  # keep only the per-round `train-rmse` lines
losses = map(log_loss) do l
    l_split = split(l)
    # `l_split[4]` is `"train-rmse:$loss_value"`,
    # so we only need to parse everything after the prefix.
    # The value starts at index 12 (`"train-rmse:"` is 11 characters long).
    l_train = l_split[4][12:end]
    return parse(Float64, l_train)
end

The downside is that this only works for certain `xgboost` options, but hopefully it can help anyone who needs a workaround.
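
For what it's worth, a regex makes the extraction a bit less brittle, since it does not depend on exact spacing or line positions (still assuming an `rmse` metric evaluated on a dataset named `train`):

losses = [parse(Float64, m.captures[1])
          for m in eachmatch(r"train-rmse:([\d.eE+-]+)", log)]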