stanfordmlgroup / ngboost

Natural Gradient Boosting for Probabilistic Prediction
Apache License 2.0

can't pickle NGBSurvival #228

Closed m23shen closed 3 years ago

m23shen commented 3 years ago

I'm having an issue saving my NGBSurvival model. Error below: PicklingError: Can't pickle <class 'ngboost.api.NGBSurvival.__init__.<locals>.SurvivalDistn'>: it's not found as ngboost.api.NGBSurvival.__init__.<locals>.SurvivalDistn

This seems to affect only the survival model: I trained an NGBRegressor and had no problem saving it.

I'm using the latest version of the package, installed with !pip install --upgrade git+https://github.com/stanfordmlgroup/ngboost.git

Please help, thanks!

MikeOMa commented 3 years ago

On a similar note for a different piece of code:

AttributeError: Can't pickle local object 'k_categorical.<locals>.Categorical'

This occurs when I try to pickle a classifier, I think.

Here's an example that gives me the error:

import pickle
import ngboost
k = ngboost.distns.k_categorical(2)
pickle.dump(k, open("file.p", "wb"))  # raises the AttributeError above

Maybe they are related?

alejandroschuler commented 3 years ago

These issues should be relatively easy to fix. The problem is that some of these distributions and other classes are generated on the fly by factory functions, so pickle can't look them up by name. The solution is to write a custom pickling method that saves the inputs to the factory function instead of the generated distribution (or other object), and then rebuilds the object from those inputs upon unpickling. See __setstate__ and __getstate__ in the current ngboost codebase here. The way I have it set up currently should in theory handle the categorical case, so I'm not totally sure what the wrinkle is. If someone could volunteer to have a go at this it would be a great help!!
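A minimal sketch of the approach described above, using only the standard library. It is not the actual ngboost code: make_distn and Model are hypothetical stand-ins for a factory such as k_categorical and for an estimator that stores the generated class. It assumes the factory inputs (here a single integer k) are themselves picklable.

```python
import pickle


def make_distn(k):
    """Hypothetical factory: builds a new class on every call."""

    class Categorical:
        n_classes = k

    return Categorical


class Model:
    """Hypothetical wrapper that stores a factory-generated class."""

    def __init__(self, k):
        self.k = k                 # factory input: a plain int, picklable
        self.Dist = make_distn(k)  # local class: pickle cannot find it by name

    def __getstate__(self):
        # Save every attribute except the generated class.
        state = self.__dict__.copy()
        del state["Dist"]
        return state

    def __setstate__(self, state):
        # Restore the plain attributes, then rebuild the class from its input.
        self.__dict__.update(state)
        self.Dist = make_distn(self.k)


# Pickling the generated class directly fails, as in the reports above.
try:
    pickle.dumps(make_distn(2))
    direct_pickle_failed = False
except (pickle.PicklingError, AttributeError):
    direct_pickle_failed = True

# Pickling the wrapper works because only the factory input is saved.
restored = pickle.loads(pickle.dumps(Model(3)))
```

After the round trip, restored.Dist is a freshly generated class rather than the original object, so identity checks against the old class will fail even though it behaves the same.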

MikeOMa commented 3 years ago

I can have a go at it.

I originally hit the k_categorical pickling failure with the version installed via pip install ngboost, so the change you mentioned might already have fixed it, if that change is not in the 0.37 release.

ryan-wolbeck commented 3 years ago

@m23shen this should be fixed in the latest version. Please re-open if you still see issues.