gallantlab / himalaya

Multiple-target linear models - CPU/GPU
https://gallantlab.github.io/himalaya
BSD 3-Clause "New" or "Revised" License

AttributeError: 'ColumnKernelizer' object has no attribute 'force_cpu' when reloading a model #18

Closed: mvdoc closed this issue 3 years ago

mvdoc commented 3 years ago

Before the force_cpu logic was implemented, I saved a model with joblib.dump. Now, after loading the model back and calling its .predict method, I get the following AttributeError:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-7-ac7b53278aaf> in <module>
----> 1 Ypred = model.predict(X_test, split=True)
      2 
      3 # Ypred = backend.to_numpy(Ypred)

~/miniconda3/envs/proj-pcan/lib/python3.7/site-packages/sklearn/utils/metaestimators.py in <lambda>(*args, **kwargs)
    117 
    118         # lambda, but not partial, allows help() to work with update_wrapper
--> 119         out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs)
    120         # update the docstring of the returned function
    121         update_wrapper(out, self.fn)

~/miniconda3/envs/proj-pcan/lib/python3.7/site-packages/sklearn/pipeline.py in predict(self, X, **predict_params)
    405         Xt = X
    406         for _, name, transform in self._iter(with_final=False):
--> 407             Xt = transform.transform(Xt)
    408         return self.steps[-1][-1].predict(Xt, **predict_params)
    409 

~/repos/himalaya/himalaya/backend/_utils.py in wrapper(*args, **kwargs)
     95     def wrapper(*args, **kwargs):
     96         # skip if the object does not force cpu use
---> 97         if not args[0].force_cpu:
     98             return func(*args, **kwargs)
     99 

AttributeError: 'ColumnKernelizer' object has no attribute 'force_cpu'
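The wrapper in the traceback reads `args[0].force_cpu` directly, so any object that lacks the attribute raises. Below is a minimal sketch of a more defensive check, assuming the intended behaviour is "do not force CPU unless `force_cpu` is truthy". The decorator name, the toy classes and the `ran_on_cpu` flag are hypothetical stand-ins, not himalaya's actual code, and the real backend switching is not reproduced.

```python
import functools


def force_cpu_if_requested(func):
    """Hypothetical decorator mirroring the check seen in the traceback."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        self = args[0]
        # An object saved before `force_cpu` existed has no such attribute;
        # default to False so it keeps the old (non-forced) behaviour.
        if not getattr(self, "force_cpu", False):
            return func(*args, **kwargs)
        # Hypothetical placeholder for switching to a CPU backend first.
        self.ran_on_cpu = True
        return func(*args, **kwargs)
    return wrapper


class OldKernelizer:
    """Toy estimator that never sets `force_cpu` (like an old pickle)."""

    @force_cpu_if_requested
    def transform(self, X):
        return [2 * x for x in X]


class NewKernelizer(OldKernelizer):
    """Toy estimator that does set `force_cpu` in __init__."""

    def __init__(self, force_cpu=False):
        self.force_cpu = force_cpu


print(OldKernelizer().transform([1, 2]))   # no AttributeError
k = NewKernelizer(force_cpu=True)
print(k.transform([3]), k.ran_on_cpu)
```

Using `getattr` with a default keeps the fix local to the decorator, so every estimator that uses it tolerates old pickles without per-class changes.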

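~I'm not sure if this is caused by some backward incompatibility, but I will try to submit a fix soon.~ This is probably caused by the missing force_cpu attribute in the ColumnKernelizer restored by joblib.load: the object was pickled before that attribute existed, and unpickling restores the saved attributes without calling __init__, so force_cpu is never set.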
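The mechanism can be reproduced with the standard `pickle` module (joblib's persistence is built on pickle). The sketch below simulates an "old" pickle by deleting `force_cpu` before dumping, then shows one possible fix: a `__setstate__` that fills in defaults for attributes added after the model was saved. The class names and the `n_features` attribute are hypothetical, and this is not necessarily how the issue was fixed upstream.

```python
import pickle


class LegacyKernelizer:
    """Toy class with no backward-compatibility handling."""

    def __init__(self, n_features=3, force_cpu=False):
        self.n_features = n_features
        self.force_cpu = force_cpu


class PatchedKernelizer(LegacyKernelizer):
    """Hypothetical fix: give missing attributes their defaults on load."""

    def __setstate__(self, state):
        # Unpickling does not call __init__, so attributes introduced
        # after the object was saved must be filled in here.
        state.setdefault("force_cpu", False)
        self.__dict__.update(state)


def simulate_old_pickle(cls):
    """Pickle an instance whose state lacks `force_cpu`, like an old model."""
    obj = cls()
    del obj.force_cpu  # mimic a model saved before the attribute existed
    return pickle.dumps(obj)


legacy = pickle.loads(simulate_old_pickle(LegacyKernelizer))
print(hasattr(legacy, "force_cpu"))  # False: __init__ was never called

patched = pickle.loads(simulate_old_pickle(PatchedKernelizer))
print(patched.force_cpu)  # False: filled in by __setstate__
```

A simpler alternative with the same effect on attribute lookup is a class-level default (`force_cpu = False` in the class body), since Python falls back to class attributes when the instance `__dict__` lacks the name.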