ogrisel / pygbm

Experimental Gradient Boosting Machines in Python with numba.
MIT License

[MRG] Add support for classification tasks #13

Closed NicolasHug closed 6 years ago

NicolasHug commented 6 years ago

Update

This PR:


Outdated

One consequence of the new implementation is that the gradients array has to be recomputed (reallocated) at every iteration now. We cannot just simply update it like before (gradients += prediction_of_current_tree), because I don't think this would work for other losses.
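A tiny NumPy sketch (illustrative values only, not pygbm code) of why the incremental shortcut only works when the gradient is linear in the raw predictions, as for least squares, but not for a logistic loss:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Illustrative values only (not pygbm code).
y_true = np.array([0.0, 1.0, 1.0])
scores = np.array([0.2, -0.1, 0.4])   # raw scores before the new tree
delta = np.array([0.05, 0.3, -0.2])   # predictions of the new tree

# Least squares: gradient = scores - y_true is linear in the scores,
# so the in-place shortcut matches a full recomputation.
grad_ls = scores - y_true
assert np.allclose(grad_ls + delta, (scores + delta) - y_true)

# Logistic loss: gradient = sigmoid(scores) - y_true is nonlinear,
# so the same shortcut diverges from the recomputed gradients.
grad_log = sigmoid(scores) - y_true
recomputed = sigmoid(scores + delta) - y_true
assert not np.allclose(grad_log + delta, recomputed)
```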

LMK if this is going in the right direction and I'll try adding a classification loss and split GradientBoostingMachine into 2 subclasses.

:)

codecov-io commented 6 years ago

Codecov Report

Merging #13 into master will increase coverage by 0.32%. The diff coverage is 97.97%.


@@            Coverage Diff            @@
##           master     #13      +/-   ##
=========================================
+ Coverage   94.17%   94.5%   +0.32%     
=========================================
  Files           8       9       +1     
  Lines         790     873      +83     
=========================================
+ Hits          744     825      +81     
- Misses         46      48       +2
| Impacted Files | Coverage Δ |
|----------------|------------|
| `pygbm/__init__.py` | 100% <100%> (ø) :arrow_up: |
| `pygbm/loss.py` | 100% <100%> (ø) |
| `pygbm/grower.py` | 89.2% <100%> (ø) :arrow_up: |
| `pygbm/plotting.py` | 100% <100%> (ø) :arrow_up: |
| `pygbm/gradient_boosting.py` | 86.5% <93.93%> (+0.68%) :arrow_up: |

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Powered by Codecov. Last update 76445b1...792777c.

ogrisel commented 6 years ago

Also please add the deviance / logistic loss for classification to this PR with tests to ensure that the design makes sense.

NicolasHug commented 6 years ago

FYI now the Higgs boson benchmark gives:

    python benchmarks/bench_higgs_boson.py --n-trees 500 --learning-rate 0.1 --n-leaf-nodes 255 --subsample 1000000

    Model     Time  acc
    LightGBM  73s   0.7488
    pygbm     92s   0.7423
ogrisel commented 6 years ago

FYI now the Higgs boson benchmark gives: ...

Nice! But do you have any idea why the classification loss is so slow in LightGBM? Do you do something special in your implementation?

ogrisel commented 6 years ago

Is pygbm always faster? Even with the full dataset on a machine with more than 4 CPU cores?

NicolasHug commented 6 years ago

I can only test on my own machine (4 cores), but:

I'll try to re-use the gradients and hessians arrays and compute them with numba and keep you updated about any improvement

NicolasHug commented 6 years ago

But do you have any idea why the classification loss is so slow in LightGBM? Do you do something special in your implementation?

Is pygbm always faster?

On my machine pygbm is slower. Is it faster on yours?

ogrisel commented 6 years ago

On my machine pygbm is slower. Is it faster on yours?

Sorry, I read it the other way around. That would have been surprising :)

ogrisel commented 6 years ago

one reason pygbm is slower than lightgbm is probably because gradients and hessians are updated inplace in lightgbm, while they are reallocated at each iteration in pygbm.

I doubt that allocation time in the outer (gradient boosting) loop has any impact on the runtime.

There is also the fact that in pygbm, the hessians of the logistic loss need the gradients, so they are computed twice. I made a comment in the code about this; we might want to use a get_gradients_and_hessians method to avoid it.

+1, this is a better design.
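The get_gradients_and_hessians idea could look roughly like this for the logistic loss (a sketch with assumed names, computing the sigmoid once and reusing it for both outputs):

```python
import numpy as np

def _sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def get_gradients_and_hessians(y_true, raw_scores):
    """Hypothetical combined method for the logistic loss: the sigmoid
    is evaluated once and shared between gradients and hessians."""
    p = _sigmoid(raw_scores)   # predicted probability of class 1
    gradients = p - y_true     # first derivative w.r.t. the raw score
    hessians = p * (1.0 - p)   # second derivative w.r.t. the raw score
    return gradients, hessians
```

For instance, at raw_scores = 0 this yields gradients of ±0.5 and hessians of 0.25.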

ogrisel commented 6 years ago

We might also want to have a single method that computes both the objective value, the gradients and the hessians at once. For some losses there might be some computation to mutualise.

NicolasHug commented 6 years ago

I doubt that allocation time in the outer (gradient boosting) loop has any impact on the runtime.

You're probably right. One other major difference that I missed is that the gradients and hessians computation is parallelized in LightGBM. That's probably a big low-hanging fruit for us... but since jitclass doesn't support @njit(parallel=True), that would make the code a bit ugly (plus, there's no built-in static scheduling).

I'd say we do it anyway? We could keep get_gradients_slow() and get_hessians_slow() to have a clean version of the methods.
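One common workaround for the jitclass limitation is to move the kernel into a module-level function, where @njit(parallel=True) is allowed, and delegate to it from the class. A rough sketch (hypothetical function name; the try/except fallback is only so the snippet runs even without numba installed):

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # pure-Python fallback so the sketch runs without numba
    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        def wrap(func):
            return func
        return wrap
    prange = range

@njit(parallel=True)
def update_gradients_logistic(gradients, y_true, raw_scores):
    # Module-level kernel: parallel=True is supported here, unlike in
    # jitclass methods. A jitclass method would just call this function.
    for i in prange(gradients.shape[0]):
        p = 1.0 / (1.0 + np.exp(-raw_scores[i]))
        gradients[i] = p - y_true[i]
```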

ogrisel commented 6 years ago

I'd say we do it anyway? We could keep get_gradients_slow() and get_hessians_slow() to have a clean version of the methods.

At least it's worth checking whether or not this is causing a significant performance difference.

NicolasHug commented 6 years ago

With the parallelized code + reuse of the gradients/hessians arrays, the above benchmark runs in 86s (vs 92s before). I have consistent results over 3 different runs.

Not much but definitely worth it IMO.

ogrisel commented 6 years ago

With the parallelized code + reuse of the gradients/hessians arrays, the above benchmark runs in 86s (vs 92s before). I have consistent results over 3 different runs.

Not much but definitely worth it IMO.

Still worth it indeed. Too bad that numba jitclass methods do not support parallel=True.

NicolasHug commented 6 years ago

I removed the get_gradients() method but kept the tests intact with a helper.

I'm not sure what to do about the other tests...

ogrisel commented 6 years ago

@NicolasHug There are some broken tests (caused by a low-level assertion error in splitting.py:246). This is probably caused by a low-level change in the threading system in numba 0.41.0: http://numba.pydata.org/numba-doc/latest/release-notes.html

I am not sure whether it's a regression in numba or a bug in our code that is revealed by stricter runtime checks in numba.

ogrisel commented 6 years ago

I confirm I can reproduce the issue when using numba 0.41.0 from PyPI. I will try with conda next.

ogrisel commented 6 years ago

I just saw you already found a workaround in #51. Will review and merge.

NicolasHug commented 6 years ago

I addressed the comments.

I haven't made any progress regarding the newton test :/ I tried with float64 but I'm still getting the convergence error.

NicolasHug commented 6 years ago

Is it possible that the newton test isn't working with the logistic loss simply because the logistic loss can never reach zero?

NicolasHug commented 6 years ago

I added a numerical test for the hessians as well.

Honestly I don't fully understand what the newton test is supposed to be doing and I'm having a hard time trying to understand why it fails.

If the ultimate goal is to make sure the gradients and hessians values are correct, then test_gradients_and_hessians_values is doing exactly this, and it does not rely on any additional routine with potential side effects that we don't control.

ogrisel commented 6 years ago

I think we should make it explicit in the loss API that y_pred from the trees is actually a raw score and needs to go through the (inverse) link function before we can interpret it and compare it to y_train. For least squares, y_pred and y_train are homogeneous and can be compared directly, but this is not the case for other loss functions in general. This is probably the cause of the newton test failure. I will try to dig a bit further.

ogrisel commented 6 years ago

Maybe we could rename y_pred to predicted_scores internally and always have:

    y_pred = self.loss_.inverse_link_function(predicted_scores)

NicolasHug commented 6 years ago

Yes, I agree 100%. I was planning to submit a PR about this. But I don't think this causes the issue, because y_pred in __call__ is homogeneous to that in update_gradients_and_hessians.

Else the numerical comparison test would fail.

ogrisel commented 6 years ago

To me the logistic loss should be:

    def __call__(self, y_true, predicted_raw_scores, average=True):
        z_pred = y_true * predicted_raw_scores
        loss = np.logaddexp(0., -z_pred)
        return loss.mean() if average else loss

and it should be very close to zero when y_true=[1.] and predicted_raw_scores=[30.] or more.

Also I have a doubt about the use of args=[y_true] for the newton method. Given those args, the loss will be called as loss(optim_param, y_true) instead of loss(y_true, optim_param).

If we want to pass fprime and fprime2, the order of the arguments of the helper functions also needs to be consistent.
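scipy.optimize.newton does indeed call func(x0, *args), so with args=[y_true] the loss receives the optimization variable first. A minimal stand-in mirroring that calling convention (hypothetical helper names; the gradient/hessian play the roles of func/fprime):

```python
import numpy as np

def newton_like(func, x0, fprime, args=(), maxiter=50, tol=1e-10):
    """Minimal stand-in mirroring scipy.optimize.newton's convention:
    func and fprime are both called as f(x, *args)."""
    x = x0
    for _ in range(maxiter):
        step = func(x, *args) / fprime(x, *args)
        x -= step
        if abs(step) < tol:
            break
    return x

# Optimization variable FIRST, so that args=(y_true,) lines up.
def logistic_grad(raw_score, y_true):
    p = 1.0 / (1.0 + np.exp(-raw_score))
    return p - y_true

def logistic_hess(raw_score, y_true):
    p = 1.0 / (1.0 + np.exp(-raw_score))
    return p * (1.0 - p)

# The root of the gradient for y_true = 0.7 is log(0.7 / 0.3).
root = newton_like(logistic_grad, 0.0, logistic_hess, args=(0.7,))
assert abs(root - np.log(0.7 / 0.3)) < 1e-8
```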

ogrisel commented 6 years ago

I think this is it. I am working on a fix for that test. Give me 5 min and I will push a commit.

ogrisel commented 6 years ago

Ok, my loss function seems to break the numerical gradient check while yours passes...

ogrisel commented 6 years ago

I messed up with git and lost my changes. I have to go offline for approximately two hours. Feel free to take over if you wish. Otherwise I will give it another try.

NicolasHug commented 6 years ago

Won't be there either ^^

NicolasHug commented 6 years ago

I'm not sure if I understand your loss formulation. Is it for a -1 / 1 encoding?

The one I'm using is from The Elements of Statistical Learning section 4.4.1

logistic loss = - log likelihood = - sum_x log(proba that x belongs to its true class)

The proba that x belongs to class 1 is defined as p_1(x) = 1 - p_0(x) = sigmoid(y_pred), where y_pred is the raw score from the trees.

A bit of calculus leads to the unifying expression used in the code:

- log(proba that x belongs to its true class) = log(1 + exp(y_pred)) - y * y_pred
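A quick finite-difference sketch of this formulation (0/1 encoding assumed, hypothetical helper names); it also shows the loss approaching, but never exactly reaching, zero for y = 1 with a large raw score:

```python
import numpy as np

def logistic_loss(y_true, raw_score):
    # log(1 + exp(raw_score)) - y_true * raw_score, computed stably
    return np.logaddexp(0.0, raw_score) - y_true * raw_score

def logistic_grad(y_true, raw_score):
    return 1.0 / (1.0 + np.exp(-raw_score)) - y_true

# Finite-difference check of the gradient of the loss above.
eps = 1e-6
for y in (0.0, 1.0):
    for s in (-2.0, 0.3, 1.5):
        num = (logistic_loss(y, s + eps) - logistic_loss(y, s - eps)) / (2 * eps)
        assert abs(num - logistic_grad(y, s)) < 1e-5

# Near zero for a confident correct prediction, but strictly positive.
assert logistic_loss(1.0, 30.0) < 1e-12
assert logistic_loss(1.0, 30.0) > 0.0
```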

ogrisel commented 6 years ago

I have the fix working. I will push it.

ogrisel commented 6 years ago

I proposed my loss from memory without checking but obviously it's wrong because the numerical gradient test did fail. Yours is correct.

ogrisel commented 6 years ago

The warnings in the newton test should disappear with the upcoming scipy 1.2.0 release: https://github.com/scipy/scipy/pull/8907.

ogrisel commented 6 years ago

@NicolasHug Indeed your analysis is right, the expression I used is only valid when y_true takes values in {-1, 1} instead of {0, 1}.

ogrisel commented 6 years ago

I am not sure that "logistic" is a good name for the binary classification loss. Maybe "binomial_nll" or "binary_crossentropy" would be fine? The latter is the name used in Keras.

"Logistic sigmoid" or "expit" is the name of the inverse link function, not the loss function.

ogrisel commented 6 years ago

In scikit-learn we use "deviance" but I think this name is a bit ambiguous and seems to refer to different things in statistics.

ogrisel commented 6 years ago

For the multiclass softmax based loss, I think "categorical_crossentropy" or "categorical_nll" would be fine.

ogrisel commented 6 years ago

We could also alias "gaussian_nll" or "gaussian_crossentropy" to the least squares loss for consistency.

NicolasHug commented 6 years ago

Thanks a lot for fixing the test!

+1 for binary_crossentropy and categorical_crossentropy

How about also accepting crossentropy which would result in either one depending on the nature of the problem?

I'll push something tomorrow or on Monday

ogrisel commented 6 years ago

Technically, least squares is also a cross-entropy, aka the NLL of a Gaussian model, resulting in the identity link function in Generalized Linear Model parlance.

It's true that it would be convenient to have a default value that switches between the binomial and categorical model depending on the number of unique classes observed in the training set. However I would rather put loss="auto" in the constructor and do the switch in the fit method of the classifier.
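The loss="auto" idea could be sketched like this (class name and attributes are hypothetical, not the actual pygbm API):

```python
import numpy as np

class SketchClassifier:
    """Stores loss="auto" in the constructor and resolves it in fit,
    based on the number of unique classes in the training labels."""

    def __init__(self, loss="auto"):
        self.loss = loss

    def fit(self, X, y):
        if self.loss == "auto":
            n_classes = np.unique(y).shape[0]
            self.loss_ = ("binary_crossentropy" if n_classes <= 2
                          else "categorical_crossentropy")
        else:
            self.loss_ = self.loss
        # ... actual boosting would happen here ...
        return self

binary = SketchClassifier().fit(None, np.array([0, 1, 1, 0]))
multi = SketchClassifier().fit(None, np.array([0, 1, 2]))
assert binary.loss_ == "binary_crossentropy"
assert multi.loss_ == "categorical_crossentropy"
```

Keeping the switch in fit (rather than the constructor) matches the scikit-learn convention that __init__ only stores parameters unchanged.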

NicolasHug commented 6 years ago

Thanks a lot for the help @ogrisel I'll merge once the checks go green and go on with #49 after rebasing