scikit-adaptation / skada

Domain adaptation toolbox compatible with scikit-learn and pytorch
https://scikit-adaptation.github.io/
BSD 3-Clause "New" or "Revised" License
60 stars · 16 forks · source link

TarS and MMDSConS methods don't work with the ImportanceWeightedScorer #151

Status: Open — opened by YanisLalou 6 months ago

YanisLalou commented 6 months ago

Error raised: TypeError: Singleton array array(None, dtype=object) cannot be considered a valid collection.

To reproduce:

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import ShuffleSplit, cross_validate
from skada import make_da_pipeline, MMDTarSReweightAdapter
from skada.datasets import make_shifted_datasets
from skada.metrics import ImportanceWeightedScorer
from skada.model_selection import StratifiedDomainShuffleSplit

# Build a tiny toy dataset containing a source domain "s" and a target domain "t".
da_dataset = make_shifted_datasets(
    n_samples_source=10,
    n_samples_target=10,
    noise=None,
    label="multiclass",
    return_dataset=True,
)

# Reweighting adapter followed by a classifier that requests the sample
# weights produced upstream, both when fitting and when scoring.
estimator = make_da_pipeline(
    MMDTarSReweightAdapter(gamma=1.0),
    LogisticRegression()
    .set_fit_request(sample_weight=True)
    .set_score_request(sample_weight=True),
)

# Use a dedicated name for the training domain labels so they are not
# confused with the test domain labels packed further below.
X_train, y_train, sample_domain_train = da_dataset.pack_train(
    as_sources=["s"], as_targets=["t"]
)

cv = StratifiedDomainShuffleSplit(n_splits=3, test_size=0.3, random_state=0)

# Per the report, running this raises:
#   TypeError: Singleton array array(None, dtype=object) cannot be
#   considered a valid collection.
# NOTE(review): confirm whether it comes from the fit or the scoring step.
# error_score="raise" makes a failing fit/score re-raise with its traceback
# instead of being recorded as NaN, so the bug surfaces deterministically.
scoring = ImportanceWeightedScorer()
scores = cross_validate(
    estimator,
    X_train,
    y_train,
    cv=cv,
    params={"sample_domain": sample_domain_train},
    scoring=scoring,
    error_score="raise",
)["test_score"]

# cross_validate fits *clones* of `estimator`; the object itself stays
# unfitted. Fit it on the full training data before predicting/scoring.
estimator.fit(X_train, y_train, sample_domain=sample_domain_train)

X_test, y_test, sample_domain_test = da_dataset.pack_test(as_targets=["t"])
y_pred = estimator.predict(X_test, sample_domain=sample_domain_test)
score = estimator.score(X_test, y_test, sample_domain=sample_domain_test)