dmlc / XGBoost.jl

XGBoost Julia Package

Saving and reloading boosters using IOBuffer #202

Open gzarpapis opened 8 months ago

gzarpapis commented 8 months ago

Hello everyone. I'm trying to write my own keep_best_model_instance function, by saving my Booster when the validation loss decreases. I understand there is a specific function to predict on the best instance, but I need the model instance itself.

All my tries with save and load, serialize and deserialize with an IOBuffer have been unsuccessful and I could not find any documentation or examples on the topic.

If anyone has done this before, could you explain how to achieve it? It might also be worth adding to the documentation.

Thank you for your time.

ExpandingMan commented 8 months ago

> All my tries with save and load, serialize and deserialize with an IOBuffer have been unsuccessful and I could not find any documentation or examples on the topic.

Can you provide an MWE or, at the very least, an explicit error dump? I don't know what to make of this description. There are no model finalization steps, so as far as I know you should be able to call XGBoost.save(booster, io) at any time between calls to update! or updateone!.

gzarpapis commented 8 months ago

Here's my MWE. Every attempt except 2, 3 and 4 throws an error; attempts 2, 3 and 4 run but give the wrong result. My next comment contains the output.

    d_train = DMatrix(x_train, label=y_train);
    d_test = DMatrix(x_test, label=y_test);

    bst = Booster(d_train; max_depth = 20, XGBoost.regression(eval_metric = "mape")..., XGBoost.randomforest()...);
    best_loss = typemax(Float32);
    best_i = 0;
    buffer_save = IOBuffer(read=true, write=true);
    buffer_serial = IOBuffer(read=true, write=true);

    for i in 1:10
        update!(bst, d_train, watchlist = watch);
        pred = XGBoost.predict(bst, x_test);
        curr_loss = error(pred, y_test);
        println(curr_loss);
        if curr_loss < best_loss

            best_i = i;
            best_loss = curr_loss

            XGBoost.save(bst, buffer_save);
            buffer_serial = IOBuffer(XGBoost.serialize(bst));
            println(best_i);        
        end
    end

    assert_it_works = final -> (begin; pred = XGBoost.predict(final, x_test); curr_loss = error(pred, y_test); return (curr_loss, "Works: " * string(curr_loss == best_loss)); end);

    final_bst = Booster(DMatrix[]);
    try
        seekstart(buffer_save);
        final_bst = XGBoost.load(typeof(XGBoost.Booster));

        println(assert_it_works(final_bst));
    catch e;
        write(stdout, string(e));
        println("\nAttempt 1 Failed\n")
    end

    try
        seekstart(buffer_save);
        final_bst = XGBoost.Booster(DMatrix[], model_buffer = buffer_save);

        println(assert_it_works(final_bst));
    catch e;
        write(stdout, string(e));
        println("\nAttempt 2 Failed\n")
    end

    try
        seekstart(buffer_save);
        XGBoost.load!(final_bst, buffer_save);

        println(assert_it_works(final_bst));
    catch e;
        write(stdout, string(e));
        println("\nAttempt 3 Failed\n")
    end

    try
        seekstart(buffer_save);
        XGBoost.load!(final_bst, take!(buffer_save));

        println(assert_it_works(final_bst));
    catch e;
        write(stdout, string(e));
        println("\nAttempt 4 Failed\n")
    end

    try
        seekstart(buffer_save);
        XGBoost.load!(final_bst, buffer_save);

        println(assert_it_works(final_bst));
    catch e;
        write(stdout, string(e));
        println("\nAttempt 5 Failed\n")
    end

    try
        seekstart(buffer_serial);
        final_bst = XGBoost.deserialize!(final_bst, buffer_serial);

        println(assert_it_works(final_bst));
    catch e;
        write(stdout, string(e));
        println("\nAttempt 6 Failed\n")
    end

    try
        seekstart(buffer_serial);
        final_bst = XGBoost.deserialize!(final_bst, take!(buffer_serial));

        println(assert_it_works(final_bst));
    catch e;
        write(stdout, string(e));
        println("\nAttempt 7 Failed\n")
    end

    try
        seekstart(buffer_serial);
        final_bst = XGBoost.deserialize!(final_bst, read(buffer_serial));

        println(assert_it_works(final_bst));
    catch e;
        write(stdout, string(e));
        println("\nAttempt 8 Failed\n")
    end

    try
        final_bst = XGBoost.deserialize(buffer_serial);

        println(assert_it_works(final_bst));
    catch e;
        write(stdout, string(e));
        println("\nAttempt 9 Failed\n")
    end

    try
        final_bst = XGBoost.deserialize(Type{XGBoost.Booster}, take!(buffer_serial));

        println(assert_it_works(final_bst));
    catch e;
        #write(stdout, string(e));
        println("\nAttempt 10 Failed\n")
    end

    try
        final_bst = XGBoost.deserialize(typeof(final_bst), take!(buffer_serial));

        println(assert_it_works(final_bst));
    catch e;
        write(stdout, string(e));
        println("\nAttempt 11 Failed\n")
    end

gzarpapis commented 8 months ago

[ Info: [1]     train-mape:1.51966291478350146  valid-mape:1.54929523610775366
1.5492952
1
[ Info: [2]     train-mape:0.40160081844888185  valid-mape:0.46602336089583418
0.46602336
2
[ Info: [3]     train-mape:0.17652566519803181  valid-mape:0.26984500867009859
0.269845
3
[ Info: [4]     train-mape:0.12303823920340601  valid-mape:0.23781751415377347
0.23781751
4
[ Info: [5]     train-mape:0.10996602844796843  valid-mape:0.23461653884055558
0.23461653
5
[ Info: [6]     train-mape:0.10870173813734055  valid-mape:0.23464444234406565
0.23464446
[ Info: [7]     train-mape:0.10713973679701468  valid-mape:0.23546314605696411
0.23546316
[ Info: [8]     train-mape:0.10624772975655812  valid-mape:0.23561708124796898
0.23561707
[ Info: [9]     train-mape:0.10658877726512184  valid-mape:0.23602501327988257
0.236025
[ Info: [10]    train-mape:0.10552233684595115  valid-mape:0.23619752531929100
0.23619755
MethodError(XGBoost.load, (DataType,), 0x0000000000007be8)
Attempt 1 Failed

(1.5492952f0, "Works: false")
(1.5492952f0, "Works: false")
(1.5492952f0, "Works: false")
XGBoost.Lib.XGBoostError(XGBoost.Lib.XGBoosterLoadModelFromBuffer, "[10:17:59] /workspace/srcdir/xgboost/src/learner.cc:1003: Check failed: fi->Read(&mparam_, sizeof(mparam_)) == sizeof(mparam_) (0 vs. 136) : BoostLearner: wrong model format\nStack trace:\n  [bt] (0) /home/zarpapis/.julia/artifacts/271facf086d4d0b748a1835be4e1208876f382f9/lib/libxgboost.so(+0x47fd34) [0x7fac29e07d34]\n  [bt] (1) /home/zarpapis/.julia/artifacts/271facf086d4d0b748a1835be4e1208876f382f9/lib/libxgboost.so(xgboost::LearnerIO::LoadModel(dmlc::Stream*)+0x2b1) [0x7fac29e2a5a1]\n  [bt] (2) /home/zarpapis/.julia/artifacts/271facf086d4d0b748a1835be4e1208876f382f9/lib/libxgboost.so(XGBoosterLoadModelFromBuffer+0x40) [0x7fac29b0e150]\n  [bt] (3) [0x7fac8f8c8fc7]\n  [bt] (4) [0x7fac8f8c9037]\n  [bt] (5) /home/zarpapis/julia-1.10.0/bin/../lib/julia/libjulia-internal.so.1.10(ijl_apply_generic+0x2ae) [0x7faca59f799e]\n  [bt] (6) [0x7fac8f86388e]\n  [bt] (7) [0x7fac8f892e77]\n  [bt] (8) [0x7fac8f897703]\n\n")
Attempt 5 Failed

MethodError(XGBoost.deserialize!, (Booster(), IOBuffer(data=UInt8[...], readable=true, writable=false, seekable=true, append=false, size=10387011, maxsize=Inf, ptr=1, mark=-1)), 0x0000000000007be8)
Attempt 6 Failed

XGBoost.Lib.XGBoostError(XGBoost.Lib.XGBoosterUnserializeFromBuffer, "[10:17:59] /workspace/srcdir/xgboost/src/learner.cc:1182: Check failed: header == serialisation_header_: If you are loading a serialized model (like pickle in Python, RDS in R) or\nconfiguration generated by an older version of XGBoost, please export the model by calling\n`Booster.save_model` from that version first, then load it back in current version. See:\n\n    https://xgboost.readthedocs.io/en/stable/tutorials/saving_model.html\n\nfor more details about differences between saving model and serializing.\n\nStack trace:\n  [bt] (0) /home/zarpapis/.julia/artifacts/271facf086d4d0b748a1835be4e1208876f382f9/lib/libxgboost.so(+0x47fd34) [0x7fac29e07d34]\n  [bt] (1) /home/zarpapis/.julia/artifacts/271facf086d4d0b748a1835be4e1208876f382f9/lib/libxgboost.so(xgboost::LearnerIO::Load(dmlc::Stream*)+0x287) [0x7fac29e2d177]\n  [bt] (2) /home/zarpapis/.julia/artifacts/271facf086d4d0b748a1835be4e1208876f382f9/lib/libxgboost.so(XGBoosterUnserializeFromBuffer+0x48) [0x7fac29b0dbf8]\n  [bt] (3) [0x7fac8f8cb587]\n  [bt] (4) [0x7fac8f8cb5f7]\n  [bt] (5) /home/zarpapis/julia-1.10.0/bin/../lib/julia/libjulia-internal.so.1.10(ijl_apply_generic+0x2ae) [0x7faca59f799e]\n  [bt] (6) [0x7fac8f86388e]\n  [bt] (7) [0x7fac8f893852]\n  [bt] (8) [0x7fac8f897703]\n\n")
Attempt 7 Failed

XGBoost.Lib.XGBoostError(XGBoost.Lib.XGBoosterUnserializeFromBuffer, "[10:17:59] /workspace/srcdir/xgboost/src/learner.cc:1182: Check failed: header == serialisation_header_: If you are loading a serialized model (like pickle in Python, RDS in R) or\nconfiguration generated by an older version of XGBoost, please export the model by calling\n`Booster.save_model` from that version first, then load it back in current version. See:\n\n    https://xgboost.readthedocs.io/en/stable/tutorials/saving_model.html\n\nfor more details about differences between saving model and serializing.\n\nStack trace:\n  [bt] (0) /home/zarpapis/.julia/artifacts/271facf086d4d0b748a1835be4e1208876f382f9/lib/libxgboost.so(+0x47fd34) [0x7fac29e07d34]\n  [bt] (1) /home/zarpapis/.julia/artifacts/271facf086d4d0b748a1835be4e1208876f382f9/lib/libxgboost.so(xgboost::LearnerIO::Load(dmlc::Stream*)+0x287) [0x7fac29e2d177]\n  [bt] (2) /home/zarpapis/.julia/artifacts/271facf086d4d0b748a1835be4e1208876f382f9/lib/libxgboost.so(XGBoosterUnserializeFromBuffer+0x48) [0x7fac29b0dbf8]\n  [bt] (3) [0x7fac8f8cb587]\n  [bt] (4) [0x7fac8f8cb5f7]\n  [bt] (5) /home/zarpapis/julia-1.10.0/bin/../lib/julia/libjulia-internal.so.1.10(ijl_apply_generic+0x2ae) [0x7faca59f799e]\n  [bt] (6) [0x7fac8f86388e]\n  [bt] (7) [0x7fac8f893ff2]\n  [bt] (8) [0x7fac8f897703]\n\n")
Attempt 8 Failed

MethodError(XGBoost.deserialize, (IOBuffer(data=UInt8[...], readable=true, writable=false, seekable=true, append=false, size=10387011, maxsize=Inf, ptr=10387012, mark=-1),), 0x0000000000007be8)
Attempt 9 Failed

Attempt 10 Failed

MethodError(Booster, (), 0x0000000000007be8)
Attempt 11 Failed
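
One possible reading of the `Works: false` results, offered as a guess rather than a confirmed diagnosis: `buffer_save` is written on every improving round without being reset, so the records pile up one after another, and a loader that starts at position 0 would meet the round-1 model first (its loss, 1.5492952, is exactly what attempts 2 to 4 report). Attempt 4 also calls `take!`, which empties a writable buffer, and that would be consistent with the `0 vs. 136` read in attempt 5. The sketch below uses only Base Julia; the strings are stand-ins for model bytes, and how libxgboost treats trailing data is an assumption that is not tested here.

```julia
# Base Julia only; the strings stand in for saved model bytes.
buf = IOBuffer(read = true, write = true)

write(buf, "model-1;")   # first "checkpoint"
write(buf, "model-2;")   # lands after the first record, it does not replace it

seekstart(buf)
contents = read(buf, String)    # both records, oldest first

# take! hands back the bytes and leaves a writable buffer empty
taken = String(take!(buf))
leftover = read(buf)            # nothing left to read

# Resetting the buffer before each write keeps only the latest record
truncate(buf, 0)
write(buf, "model-3;")
seekstart(buf)
latest = read(buf, String)
```

Creating a fresh `IOBuffer()` for each checkpoint has the same effect as the `truncate(buf, 0)` call.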

ExpandingMan commented 8 months ago

So, when you call load, the buffer is passed as a parameter to the model constructor, whereas load! calls a different libxgboost function, XGBoosterLoadModelFromBuffer. Confusingly, this does NOT appear to be the inverse of XGBoosterSaveModelToBuffer, which is what we use to save. Confusion over the various serialization methods in libxgboost has come up a number of times, and it seems it is still not fully resolved.

I have opened this issue in libxgboost to try to get some clarification. I agree that, regardless of what the maintainers say, the current state of the load, load! and save functions in the Julia wrapper is extremely confusing (it is safe to say that I, and probably the other maintainers, are confused about the intended use of these functions), so I think we should either get those functions to work the way you are trying to use them or deprecate load!.

In the meantime, I suggest you simply use save and load, and not load!. It's not clear to me that there's any real disadvantage to this, as the overhead of creating the booster object in the first place is pretty low.
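
A minimal sketch of the keep-the-best-snapshot loop this advice implies, written generically so it runs without XGBoost installed. The function `keep_best` and its callbacks `step!`, `loss` and `snapshot` are hypothetical names introduced for illustration; they are not part of XGBoost.jl.

```julia
# Keep the bytes of the best snapshot seen so far.
# `step!`, `loss` and `snapshot` are hypothetical callbacks supplied by the caller.
function keep_best(step!, loss, snapshot; rounds::Integer)
    best_loss, best_round, best_bytes = Inf, 0, UInt8[]
    for i in 1:rounds
        step!()                       # one training round
        current = loss()              # validation loss after this round
        if current < best_loss
            io = IOBuffer()           # fresh buffer, so older snapshots never linger
            snapshot(io)              # write the current model into it
            best_loss, best_round, best_bytes = current, i, take!(io)
        end
    end
    return (; best_loss, best_round, best_bytes)
end

# Stand-in "model": a counter whose validation loss dips at round 3.
losses = [1.5, 0.4, 0.2, 0.3, 0.25]
state = Ref(0)
result = keep_best(
    () -> (state[] += 1),
    () -> losses[state[]],
    io -> write(io, "round-$(state[])");
    rounds = length(losses)
)
```

With XGBoost.jl, the callbacks would presumably wrap calls that already appear in this thread: `update!(bst, d_train)` for `step!`, `XGBoost.save(bst, io)` for `snapshot`, and `Booster(DMatrix[], model_buffer = IOBuffer(result.best_bytes))` to rebuild the model afterwards. That wiring is an assumption and has not been run against the package.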

gzarpapis commented 8 months ago

After looking into it some more, I've realized a few things.

1. I was using serialize wrong. It should look like this:

       # Make a new IOBuffer or ...
       # ... empty the buffer if reusing it in a loop
       b = IOBuffer(read = true, write = true);

       # Write to buffer
       write(b, XGBoost.serialize(bst));

       # - - - - - Do other things - - - - -

       # Create a new Booster
       final_bst = Booster(DMatrix[]);

       # Go to position 0 of buffer and deserialize
       seekstart(b); deserialize!(final_bst, read(b)); # Or take!(b)

The `final_bst` object should now contain the saved instance of the model rather than the latest one. In my example above I was using the `IOBuffer` incorrectly, so disregard that part.

2. `save` and `load` work with their string methods, which write to and read from a file, but I still don't know whether they work with buffers.

3. The former functions are not even necessary if you don't need a human-readable `json` format. The contents of the IOBuffer can be written to a file and read back without any problem, so serializing and deserializing can be used for all of these purposes instead.
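
A small sketch of the file round trip described in point 3, using only Base Julia. The byte vector stands in for the output of `XGBoost.serialize(bst)`, and the file name `booster.bin` is a hypothetical choice.

```julia
# Base Julia only. `bytes` is a stand-in for the output of XGBoost.serialize(bst).
bytes = UInt8[0x01, 0x02, 0x03, 0xff]

b = IOBuffer()
write(b, bytes)

path = joinpath(mktempdir(), "booster.bin")   # hypothetical file name
write(path, take!(b))                         # dump the buffer contents to disk

restored = read(path)                         # Vector{UInt8}, the same bytes read back
```

The `restored` vector is what would then be handed to `deserialize!`. The libxgboost error text quoted earlier in the thread links to a page on the differences between saving a model and serializing it, which is worth reading before relying on serialized bytes for long-term storage.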