juanramonua closed this 4 years ago
Good catch... I looked into it, and this is happening because the way `predict` is implemented right now is a bit silly and only works for distributions that have a `loc` parameter.
The way to fix this, imo, is to define a `predict` method for all distributions that returns the conditional mean for continuous distributions (i.e. the `loc` parameter for the ones we've implemented) and the conditional mode (i.e. the most likely class) for discrete distributions. We can probably even use the underlying scipy distributions for this. I'll do that sometime soon.
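Roughly, the idea could look like the sketch below. `NormalDist` and `BernoulliDist` are hypothetical stand-ins, not ngboost's actual distribution classes; the point is only that each distribution owns its `predict`, returning the mean (via the scipy distribution) in the continuous case and the most likely class in the discrete case.

```python
import numpy as np
from scipy import stats


class NormalDist:
    """Hypothetical continuous distribution wrapper (not ngboost's real class)."""

    def __init__(self, loc, scale):
        self.loc = np.asarray(loc, dtype=float)
        self.scale = np.asarray(scale, dtype=float)
        # Frozen scipy distribution holding one (loc, scale) pair per row.
        self.dist = stats.norm(loc=self.loc, scale=self.scale)

    def predict(self):
        # Continuous case: conditional mean, which for a normal equals loc.
        return self.dist.mean()


class BernoulliDist:
    """Hypothetical discrete distribution wrapper (not ngboost's real class)."""

    def __init__(self, prob):
        # prob holds P(y = 1) for each row.
        self.prob = np.asarray(prob, dtype=float)

    def predict(self):
        # Discrete case: conditional mode, i.e. the most likely class.
        # For a Bernoulli that is 1 when P(y = 1) > 0.5, else 0 (ties go to 0).
        return (self.prob > 0.5).astype(int)
```

For example, `NormalDist([1.0, -2.0], [1.0, 3.0]).predict()` gives the `loc` values back, and `BernoulliDist([0.2, 0.7, 0.5]).predict()` gives `[0, 1, 0]`, the same rule as thresholding `prob` at 0.5.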
In the meantime, you can easily get the predicted class with `ngb.pred_dist(X_test).prob > 0.5`.
Ok, I've fixed this issue with https://github.com/stanfordmlgroup/ngboost/commit/c05485187f1952b62797499ddbcc0b67618550aa. Please do play around with it and let me know if this now works for you and/or if you want other features.
Running your code from https://github.com/stanfordmlgroup/ngboost/blob/master/examples/classification.py, I received the following error when I call `ngb.predict(X_test)`:
`--------------------------------------------------------------------------- AttributeError Traceback (most recent call last)