scikit-adaptation / skada

Domain adaptation toolbox compatible with scikit-learn and pytorch
https://scikit-adaptation.github.io/
BSD 3-Clause "New" or "Revised" License
60 stars, 16 forks

Add predict_proba to JDOTClassifier #150

Closed. YanisLalou closed this issue 6 months ago.

YanisLalou commented 6 months ago

Some skada scorers require the estimator to expose predict_proba, so JDOTClassifier needs to implement it. Proposed code:

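import warnings

import numpy as np
from sklearn.utils.validation import check_is_fitted


def predict_proba(self, X, sample_domain=None, *, sample_weight=None):
    """Predict class probabilities using the fitted estimator."""
    check_is_fitted(self)
    # Non-negative domain labels are treated as source samples
    # (skada marks target samples with negative labels).
    # sample_weight is accepted for signature compatibility but unused here.
    if sample_domain is not None and np.any(np.asarray(sample_domain) >= 0):
        warnings.warn(
            "Source domain detected. Predictor is trained on target "
            "and prediction might be biased."
        )
    return self.estimator_.predict_proba(X)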
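The stdlib-only sketch below exercises the proposed predict_proba logic without needing skada or scikit-learn installed. ConstantProbaEstimator and TargetOnlyClassifier are hypothetical stand-ins, not skada classes, and a hasattr check replaces check_is_fitted. It shows that target-only inputs pass silently while any non-negative domain label triggers the warning.

```python
import warnings


class ConstantProbaEstimator:
    """Hypothetical stand-in for an already fitted classifier."""

    def predict_proba(self, X):
        # Return the same two-class distribution for every sample.
        return [[0.25, 0.75] for _ in X]


class TargetOnlyClassifier:
    """Hypothetical wrapper mimicking the proposed predict_proba logic."""

    def __init__(self, estimator):
        # Pretend fitting already happened: store the fitted estimator under
        # a trailing-underscore attribute, as scikit-learn estimators do.
        self.estimator_ = estimator

    def predict_proba(self, X, sample_domain=None, *, sample_weight=None):
        # Simplified stand-in for sklearn's check_is_fitted.
        if not hasattr(self, "estimator_"):
            raise RuntimeError("Call fit before predict_proba.")
        # Warn if any sample carries a non-negative (source) domain label.
        if sample_domain is not None and any(d >= 0 for d in sample_domain):
            warnings.warn(
                "Source domain detected. Predictor is trained on target "
                "and prediction might be biased."
            )
        # Delegate to the fitted inner estimator.
        return self.estimator_.predict_proba(X)


clf = TargetOnlyClassifier(ConstantProbaEstimator())
X = [[0.0], [1.0]]

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    target_proba = clf.predict_proba(X, sample_domain=[-1, -1])  # no warning
    source_proba = clf.predict_proba(X, sample_domain=[1, -1])   # warns once

print(target_proba)  # [[0.25, 0.75], [0.25, 0.75]]
print(len(caught))   # 1
```

Note that the two string literals in the warning are implicitly concatenated, so the trailing space after "target" is what keeps the message readable.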

Additional error to fix:

  File "/Users/yanislalou/Documents/CMAP/skada_test_venv/lib/python3.9/site-packages/sklearn/metrics/_scorer.py", line 141, in __call__
    score = scorer(estimator, *args, **routed_params.get(name).score)
  File "/Users/yanislalou/Documents/CMAP/skada_test_venv/lib/python3.9/site-packages/skada/metrics.py", line 42, in __call__
    return self._score(estimator, X, y, sample_domain=sample_domain, **params)
  File "/Users/yanislalou/Documents/CMAP/skada_test_venv/lib/python3.9/site-packages/skada/metrics.py", line 175, in _score
    return self._sign * scorer(
  File "/Users/yanislalou/Documents/CMAP/skada_test_venv/lib/python3.9/site-packages/sklearn/metrics/_scorer.py", line 415, in __call__
    return estimator.score(*args, **kwargs)
  File "/Users/yanislalou/Documents/CMAP/skada_test_venv/lib/python3.9/site-packages/sklearn/pipeline.py", line 1007, in score
    return self.steps[-1][1].score(Xt, y, **routed_params[self.steps[-1][0]].score)
  File "/Users/yanislalou/Documents/CMAP/skada_test_venv/lib/python3.9/site-packages/skada/base.py", line 273, in score
    return self._route_to_estimator('score', X, y=y, **params)
  File "/Users/yanislalou/Documents/CMAP/skada_test_venv/lib/python3.9/site-packages/skada/base.py", line 388, in _route_to_estimator
    output = method(X, **routed_params) if y is None else method(
TypeError: score() got an unexpected keyword argument 'allow_source'

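Solution: add **kwargs to the score method so that it accepts the extra allow_source keyword argument forwarded by skada's routing instead of raising TypeError.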
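A minimal sketch of why the **kwargs fix removes the TypeError. StrictScorer, TolerantScorer and the route_score helper are hypothetical; route_score only imitates the way the traceback suggests skada's routing forwards allow_source to score, and none of these are skada APIs.

```python
class StrictScorer:
    """Hypothetical estimator whose score rejects unknown keywords."""

    def score(self, X, y, sample_weight=None):
        # Plain accuracy: fraction of positions where X and y agree.
        return sum(1 for a, b in zip(X, y) if a == b) / len(y)


class TolerantScorer(StrictScorer):
    """Same score, but extra routed keywords are accepted and ignored."""

    def score(self, X, y, sample_weight=None, **kwargs):
        # Keywords such as allow_source land in kwargs instead of
        # raising TypeError; they are not used.
        return super().score(X, y, sample_weight=sample_weight)


def route_score(estimator, X, y):
    # Hypothetical router that forwards an extra keyword argument.
    return estimator.score(X, y, allow_source=True)


X = [1, 0, 1, 1]
y = [1, 0, 0, 1]

error_message = ""
try:
    route_score(StrictScorer(), X, y)
except TypeError as exc:
    error_message = str(exc)

tolerant_score = route_score(TolerantScorer(), X, y)
print(error_message)   # contains: unexpected keyword argument 'allow_source'
print(tolerant_score)  # 0.75
```

Swallowing everything in **kwargs also silently ignores misspelled keywords; declaring the specific routed parameter explicitly would be a stricter alternative.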