dmlc / XGBoost.jl

XGBoost Julia Package
Other
288 stars 110 forks source link

Capturing evaluation log solved ..... sort of #148

Closed bobaronoff closed 1 year ago

bobaronoff commented 1 year ago

I and others have asked to be able to capture the evaluation log. My reasons involve how I was trained to build a tree boost model. Key in the workflow is to combine cross validation with a learning curve to optimize number of trees and assess for overfitting. I have written a routine to perform this task with multiple calls to predict and calculating the evaluation metrics explicitly. Although I've got this working nicely, it involves writing code specific for each objective and evaluation metric. It would be so nice (and more robust) to take advantage of the facilities in libxgboost.

I have written two functions. One creates the Booster and captures the evaluation log, and the second parses the log into a DataFrame. I am NO programmer. The first function is a cut-and-paste of several functions from booster.jl; the parse function is my personal hack.

This seems to be working on a few configurations that I've tested. The main function xgboost_log accepts the data as a DMatrix object. To be honest, I am confused by all the different ways to pass in data. I've started to explicitly create the DMatrix because that is the only way I've been able to pass in a weight vector (I'm sure there are other ways, but this is what worked for me). The function will also accept test data (as a DMatrix) via the keyword parameter testdm.

Would love feedback as to whether there is a way to improve the functions. For those looking for this feature, feel free to use the code.

"""
    xgboost_log(traindm::DMatrix, a...; testdm=nothing, num_round=10, kw...)

Train a `Booster` on `traindm` for `num_round` boosting rounds while capturing
the per-iteration evaluation log produced by libxgboost.  If `testdm` is a
`DMatrix`, it is evaluated alongside the training data every round; any other
value (including the default `nothing`) means no test set.

Extra positional arguments `a...` are accepted for call-site compatibility but
are not used; keyword arguments `kw...` are forwarded to the `Booster`
constructor.

Returns a named tuple `(booster=b, log=df)` where `df` is the evaluation log
parsed into a `DataFrame` by `parsethelog`.
"""
function xgboost_log(traindm::DMatrix, a...;
                        testdm::Any=nothing,
                        num_round::Integer=10,
                        kw...
                    )

    Xy = XGBoost.DMatrix(traindm)
    b = XGBoost.Booster(Xy; kw...)
    update_feature_names = false
    # Only include a "test" entry when an actual DMatrix was supplied.
    if testdm isa DMatrix
        watchlist = Dict("train" => traindm, "test" => testdm)
    else
        watchlist = Dict("train" => traindm)
    end
    # Keys and values of the same Dict iterate in matching order, so these
    # two collections stay aligned.
    names = collect(Iterators.map(string, keys(watchlist)))
    watch = collect(Iterators.map(x -> x.handle, values(watchlist)))
    thelog = String[]
    for j in 1:num_round
        XGBoost.xgbcall(XGBoost.XGBoosterUpdateOneIter, b.handle, j, Xy.handle)
        # libxgboost owns the returned C string; unsafe_string copies it out.
        o = Ref{Ptr{Int8}}()
        XGBoost.xgbcall(XGBoost.XGBoosterEvalOneIter, b.handle, j, watch, names, length(watch), o)
        push!(thelog, unsafe_string(o[]))
        XGBoost._maybe_update_feature_names!(b, Xy, update_feature_names)
    end
    return (booster=b, log=parsethelog(thelog))
end

"""
    parsethelog(thelog::Vector{String})

Parse raw libxgboost evaluation-log lines of the form
`"[i]\\tname1:val1\\tname2:val2..."` into a `DataFrame` with an `iteration`
column followed by one `Float64` column per evaluation metric, named after
the metric labels found on the first line.

Returns an empty `DataFrame` when `thelog` is empty.
"""
function parsethelog(thelog::Vector{String})
    nr = length(thelog)
    # Guard: an empty log would otherwise throw a BoundsError on thelog[1].
    nr == 0 && return DataFrame()
    # Splitting on tab and colon turns each line into:
    #   ["[i]", name1, val1, name2, val2, ...]
    splitline(s) = split(replace(replace(s, "\t" => ","), ":" => ","), ",")
    firstfields = splitline(thelog[1])
    # One (name, value) pair per metric after the leading "[i]" field.
    neval = (length(firstfields) - 1) ÷ 2
    evalnames = [String(firstfields[2c]) for c in 1:neval]
    vals = zeros(nr, neval)
    rnd = zeros(Int, nr)
    for r in 1:nr
        fields = splitline(thelog[r])
        # Strip the surrounding brackets from "[i]" to get the round number.
        rnd[r] = parse(Int, fields[1][2:end-1])
        for c in 1:neval
            vals[r, c] = parse(Float64, fields[1 + 2c])
        end
    end
    return hcat(DataFrame(iteration=rnd), DataFrame(vals, evalnames))
end