ersilia-os / chempfn

Ensemble-based, size-agnostic wrapper for the TabPFN classifier
GNU General Public License v3.0

Bug: Predict proba #18

Closed by GemmaTuron 1 year ago

GemmaTuron commented 1 year ago

Hi @DhanshreeA

When I call predict_proba instead of predict, I get the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[13], line 1
----> 1 preds = clf.predict_proba(X_test)

File ~/anaconda3/envs/eosce/lib/python3.8/site-packages/ensemble_tabpfn/ensemble_tabpfn.py:171, in EnsembleTabPFN.predict_proba(self, X)
    169 result = self._predict(X)
    170 result.aggregate
--> 171 return result.probs

AttributeError: 'Result' object has no attribute 'probs'

DhanshreeA commented 1 year ago

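Thanks for flagging this, @GemmaTuron. I'm pushing a fix and a test. The issue is that I never actually called the aggregate function: line 170 in the traceback reads result.aggregate, without parentheses, which only references the method. Because aggregate never runs, the probs attribute is never set on the result.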
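
For anyone hitting the same error, here is a minimal sketch of the failure mode and the fix. The Result class below is a hypothetical stand-in written for illustration, not the actual ensemble_tabpfn implementation; the per_model_probs attribute and the averaging logic are assumptions. Only the names aggregate and probs come from the traceback above.

```python
class Result:
    """Hypothetical stand-in for the library's Result class (assumption)."""

    def __init__(self, per_model_probs):
        # One list of class probabilities per ensemble member (assumed layout).
        self.per_model_probs = per_model_probs

    def aggregate(self):
        # Average each class probability across ensemble members.
        # self.probs only exists after this method has actually run.
        n = len(self.per_model_probs)
        self.probs = [sum(col) / n for col in zip(*self.per_model_probs)]


def predict_proba_buggy(result):
    result.aggregate      # attribute access only: the method is never called
    return result.probs   # raises AttributeError, as in the traceback


def predict_proba_fixed(result):
    result.aggregate()    # parentheses call the method, which sets probs
    return result.probs
```

Calling predict_proba_buggy on a fresh Result raises the same AttributeError as in the report, while predict_proba_fixed returns the averaged probabilities. A regression test that calls predict_proba and checks the returned shape would catch this class of bug.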