scikit-adaptation / skada

Domain adaptation toolbox compatible with scikit-learn and pytorch
https://scikit-adaptation.github.io/
BSD 3-Clause "New" or "Revised" License

[MRG] Add kwargs to DASVM predict/predict_proba + add score func #160

Closed YanisLalou closed 2 months ago

YanisLalou commented 3 months ago
codecov[bot] commented 3 months ago

Codecov Report

Merging #160 (d408d87) into main (ff06a8b) will decrease coverage by 0.04%. The diff coverage is 92.30%.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #160      +/-   ##
==========================================
- Coverage   97.39%   97.36%   -0.04%
==========================================
  Files          48       48
  Lines        4422     4433      +11
==========================================
+ Hits         4307     4316       +9
- Misses        115      117       +2
```
YanisLalou commented 3 months ago

You added **kwargs to the function definition but forgot to pass them on to self.base_estimator.predict_proba(X, **kwargs). Once they are passed on, we get TypeError: predict_proba() got an unexpected keyword argument 'sample_domain', because BaseSVC doesn't accept **kwargs.

2 possible solutions:

    # If kwargs are accepted, pass them to predict_proba
    try:
        return predict_proba(X, **kwargs)
    except TypeError:
        # If kwargs are not accepted, call predict_proba without kwargs
        return predict_proba(X)
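
A minimal, self-contained sketch of this try/except fallback, to make its behaviour concrete. `PlainEstimator`, `DomainAwareEstimator` and `predict_proba_with_fallback` are hypothetical names used only for illustration; they are not part of skada or scikit-learn.

```python
class PlainEstimator:
    """Hypothetical estimator whose predict_proba accepts only X."""

    def predict_proba(self, X):
        return [[0.5, 0.5] for _ in X]


class DomainAwareEstimator:
    """Hypothetical estimator whose predict_proba also accepts sample_domain."""

    def __init__(self):
        self.seen_domain = None

    def predict_proba(self, X, sample_domain=None):
        # Record what was received so the forwarding can be checked.
        self.seen_domain = sample_domain
        return [[0.5, 0.5] for _ in X]


def predict_proba_with_fallback(estimator, X, **kwargs):
    """Try forwarding kwargs; retry without them on TypeError."""
    try:
        return estimator.predict_proba(X, **kwargs)
    except TypeError:
        # Caveat: this also catches TypeErrors raised inside predict_proba
        # for unrelated reasons, and then runs the prediction a second time.
        return estimator.predict_proba(X)


X = [[0.0], [1.0]]
plain_result = predict_proba_with_fallback(PlainEstimator(), X, sample_domain=[1, -2])
aware = DomainAwareEstimator()
aware_result = predict_proba_with_fallback(aware, X, sample_domain=[1, -2])
```

The caveat in the `except` branch is the main trade-off of this option: a genuine `TypeError` inside the wrapped estimator would be masked by the retry.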

OR

    # Check if 'kwargs' is among the parameters
    if 'kwargs' in predict_proba_signature.parameters:
        # If 'kwargs' is accepted, pass it to predict_proba
        return predict_proba(X, **kwargs)
    elif 'sample_domain' in predict_proba_signature.parameters:
        # If 'sample_domain' is accepted, pass it along with X
        return predict_proba(X, sample_domain=kwargs.get('sample_domain'))
    else:
        # If neither 'kwargs' nor 'sample_domain' is accepted,
        # call predict_proba without them
        return predict_proba(X)
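
A possible variant of the signature check, sketched under stated assumptions. Instead of looking for a parameter literally named `kwargs`, it looks for any variadic keyword parameter (`inspect.Parameter.VAR_KEYWORD`), so a function declared with `**params` is also recognised, and otherwise forwards only the keywords the target names explicitly. `filter_supported_kwargs` and the three toy functions are hypothetical and not part of skada.

```python
import inspect


def filter_supported_kwargs(func, kwargs):
    """Return the subset of kwargs that func can accept by keyword."""
    params = inspect.signature(func).parameters
    # A **something parameter means every keyword is accepted.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    # Otherwise keep only names that can be passed by keyword.
    by_keyword = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    accepted = {name for name, p in params.items() if p.kind in by_keyword}
    return {k: v for k, v in kwargs.items() if k in accepted}


def takes_only_x(X):
    return X


def takes_domain(X, sample_domain=None):
    return X


def takes_anything(X, **params):
    return X


extra = {"sample_domain": [1, -2], "allow_source": True}
only_x_kwargs = filter_supported_kwargs(takes_only_x, extra)
domain_kwargs = filter_supported_kwargs(takes_domain, extra)
anything_kwargs = filter_supported_kwargs(takes_anything, extra)
```

With such a helper the call site could reduce to `predict_proba(X, **filter_supported_kwargs(predict_proba, kwargs))`, avoiding both the retry and the hard-coded `sample_domain` branch.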

What do you think?

YanisLalou commented 3 months ago

OR

we can change the metadata requests for this specific object, i.e.:

    __metadata_request__fit = {'sample_domain': True}
    __metadata_request__partial_fit = {'sample_domain': False}
    __metadata_request__predict = {'sample_domain': False, 'allow_source': False}
    __metadata_request__predict_proba = {'sample_domain': False, 'allow_source': False}
    __metadata_request__predict_log_proba = {
        'sample_domain': False,
        'allow_source': False
    }
    __metadata_request__score = {'sample_domain': False, 'allow_source': False}
    __metadata_request__decision_function = {
        'sample_domain': False,
        'allow_source': False
    }