scailable / sclblpy

Python package for Scailable uploads
MIT License

_model_is_fitted() and _predict for statsmodels #5

Closed MKaptein closed 4 years ago

MKaptein commented 4 years ago

The behavior of the statsmodels package for checking whether or not the model is fitted and for generating predictions seems inconsistent. See https://github.com/scailable/sclblpy/blob/develop/sclblpy/_utils.py lines 158 and 212. This should be checked (and should also be checked with the toolchain).

Note that we should be able to run all the tests in test_utils.py, which cover these cases.

MKaptein commented 4 years ago

Also check whether the toolchain-server / the code in the Docker container is affected.

MKaptein commented 4 years ago

So, line 212 should be:

result = mod.predict(feature_vector.reshape(1, -1))
return result.tolist()

And not the .fit() version, it seems. Currently running more tests.

MKaptein commented 4 years ago

Oh, great. I see that you:

    try:
        model_ok = _check_model(mod)
    except ModelSupportError as e:
        print("FATAL: The model you are trying to upload is not (yet) supported. \n"
              "Please see README.md for a list of supported models.")
        if _verbose:
            print(str(e))
        return False
    bundle['fitted_model'] = mod

have ripped out the entire check of whether the model is supported.... :S.

MKaptein commented 4 years ago

Ok, the problem is actually simple:

For statsmodels we do:

estimator = sm.OLS(y, X)
mod = estimator.fit()

the question is: which one do we pass, estimator or mod? My code worked for estimator; yours "works" for mod.

MKaptein commented 4 years ago

So, it all works (and already worked), as long as we use:

# for statsmodels:
mod = sm.OLS(y,X)
mod.fit()

upload(mod, ...)

And thus NOT:

# for statsmodels:
estimator = sm.OLS(y,X)
mod = estimator.fit()

upload(mod, ...)

I will add this to the readme as an example and close this issue.

robinvanemden commented 4 years ago

On further delving into statsmodels, it appears that, in contrast to SkLearn style processing:

mod = tree.DecisionTreeRegressor()
mod.fit(X, y)      
mod                            #<--- contains fitted model, client and server both happy!

... statsmodels returns the fitted model as a separate object:

mod = sm.OLS(y,X)
fitted = mod.fit()
mod                            # <--- does not contain fitted model, used client side
fitted                         # <--- contains fitted model, needed server side.

Happily, the fitted object incorporates the mod instance as fitted.model:

mod = sm.OLS(y,X)
fitted = mod.fit()

print( mod == fitted.model )   # <--- returns "True"

fitted.model                   # <--- use client side
fitted                         # <--- upload

... so there is no need for any major changes to the package; we only have to add .model to the applicable objects, and have users upload fitted, i.e. not fitted.model.

MKaptein commented 4 years ago

Why not just do mod.fit() at the server-side?

robinvanemden commented 4 years ago

That would be the other option, but it would mean that the model is fitted first client side and then again server side. In general, our platform "runs fitted models", and this conforms with the way sklearn and XGBoost are processed (fitting client side, transpiling client side). Finally, it seems an easy client-side fix. Nevertheless, it would also be an easy server-side fix. My vote would be client side, but I am open to implementing this on the server side if you strongly prefer that.

MKaptein commented 4 years ago

Ok, so this is now implemented such that we expect, for statsmodels, the following behavior:

    X, y = iris_data
    est = sm.OLS(y, X)
    mod = est.fit()
    docs = {'name': "OLS test"}
    fv = X[0, :]
    upload(mod, fv, docs=docs)

and hence an object called RegressionResultsWrapper is sent to the toolchain-server (not the estimator itself).

This has affected the code in _utils.py, and test_utils.py is updated accordingly. Note that the only place where this is admittedly ugly is in the _model_supported() function: since this function checks the name of the object against the models listed in supported_models.json, there is an exception for statsmodels (ln 62):

        model_base: str = _get_model_package(obj)  # This returns a string OR raises an error
        # statsmodels hack for passing RegressionResultsWrapper:
        if model_base == "statsmodels":
            model_name: str = _get_model_name(obj.model)
        else:
            model_name: str = _get_model_name(obj)

I guess we can live with this :).