scikit-adaptation / skada

Domain adaptation toolbox compatible with scikit-learn and pytorch
https://scikit-adaptation.github.io/
BSD 3-Clause "New" or "Revised" License

Deep copy a DomainAwareNet after fitting #139

Open YanisLalou opened 4 months ago

YanisLalou commented 4 months ago

After fitting a DomainAwareNet, it can no longer be deep-copied, although deepcopy works fine before fitting. The error raised is: AttributeError: Can't pickle local object '_get_intermediate_layers.<locals>.hook'

Being able to deep copy a fitted estimator is required to use the CircularValidation scorer.
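The error message points at `hook`, a function defined inside `_get_intermediate_layers`, and says it cannot be pickled. As a minimal standard-library sketch of that mechanism (not the skada code itself): `register_hook`, `ModuleLevelHook` and `try_pickle` below are hypothetical names used only for illustration.

```python
import pickle


def register_hook():
    # Hypothetical stand-in for a helper that builds its hook as a nested function.
    def hook(module, inputs, output):
        return output

    return hook


class ModuleLevelHook:
    # A callable class defined at module level can be looked up by name,
    # which is what pickle needs in order to serialize an instance of it.
    def __call__(self, module, inputs, output):
        return output


def try_pickle(obj):
    """Return None if pickling succeeds, otherwise the error message."""
    try:
        pickle.dumps(obj)
    except (AttributeError, pickle.PicklingError) as exc:
        return str(exc)
    return None


# The nested function cannot be pickled; the message names it as a local object.
local_error = try_pickle(register_hook())
# The module-level callable is expected to pickle when this runs as a script.
module_level_error = try_pickle(ModuleLevelHook())

print(local_error)
print(module_level_error)
```

This only shows why an object holding a nested function fails as soon as a pickling step is involved. Which step of the deep copy triggers pickling for a fitted DomainAwareNet is not shown here.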

To reproduce:

import torch
import numpy as np
from skada.deep.base import DomainAwareCriterion, DomainAwareModule, DomainAwareNet, DomainBalancedDataLoader, BaseDALoss
from skada.metrics import (
    ImportanceWeightedScorer,
    PredictionEntropyScorer,
    SoftNeighborhoodDensity,
    DeepEmbeddedValidation,
    CircularValidation,
)
from skada.deep.modules import ToyModule2D
from skada import make_da_pipeline
from sklearn.model_selection import ShuffleSplit, cross_validate
from skada.datasets import make_shifted_datasets
from copy import deepcopy

class TestLoss(BaseDALoss):
    """Test Loss to check the deep API"""

    def __init__(
        self,
    ):
        super().__init__()

    def forward(
        self,
        *args,
    ):
        """Compute the domain adaptation loss"""
        return 0

da_dataset = make_shifted_datasets(
    n_samples_source=20,
    n_samples_target=20,
    shift="concept_drift",
    noise=0.1,
    random_state=42,
    return_dataset=True,
)

X, y, sample_domain = da_dataset.pack_train(as_sources=["s"], as_targets=["t"])
X_test, y_test, sample_domain_test = da_dataset.pack_test(as_targets=["t"])

module = ToyModule2D()

estimator = DomainAwareNet(
    DomainAwareModule(module, "dropout"),
    iterator_train=DomainBalancedDataLoader,
    criterion=DomainAwareCriterion(torch.nn.CrossEntropyLoss(), TestLoss()),
    batch_size=10,
    max_epochs=2,
    train_split=None,
)

X = X.astype(np.float32)
X_test = X_test.astype(np.float32)

estimator_copy = deepcopy(estimator) # Doesn't raise errors

estimator.fit(X, y, sample_domain=sample_domain)

estimator_copy_after_fit = deepcopy(estimator)  # Raises AttributeError

kachayev commented 4 months ago

We had a brief discussion of this problem in the Slack chat. It also makes the behavior of selectors somewhat unpredictable, as they clone the base estimator before fitting. The reason the copy fails is that torch blocks copying of a tensor that is attached to the gradient graph. It should be possible to define our own cloning hook by taking the parameters of the module and calling the constructor (I'm not sure whether that works for deepcopy as well, but I would say it's better to switch the splitter to sklearn clone).
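The "take the parameters and call the constructor" idea can be sketched in plain Python. This is only an illustration under assumptions: `ToyEstimator` and `rebuild_unfitted` are hypothetical and are not part of skada, skorch or scikit-learn. The same principle is behind `sklearn.base.clone`, which builds a new unfitted estimator with the same parameters instead of copying fitted state.

```python
class ToyEstimator:
    """Hypothetical estimator used only to illustrate the idea."""

    def __init__(self, layer_name="dropout", max_epochs=2):
        self.layer_name = layer_name
        self.max_epochs = max_epochs

    def get_params(self):
        # Only constructor arguments are returned, never fitted state.
        return {"layer_name": self.layer_name, "max_epochs": self.max_epochs}

    def fit(self):
        def hook(output):
            # Nested function, similar in spirit to the one named in the error.
            return output

        # Fitted state that a pickle-based copy would not be able to handle.
        self.hook_ = hook
        return self


def rebuild_unfitted(estimator):
    """Create a fresh, unfitted instance from the constructor parameters."""
    return type(estimator)(**estimator.get_params())


fitted = ToyEstimator(max_epochs=5).fit()
fresh = rebuild_unfitted(fitted)

print(fresh.get_params())
print(hasattr(fresh, "hook_"))
```

The rebuilt object never touches the fitted attributes, so the hook is not copied at all. The trade-off is that the result is unfitted, which suits code that refits the estimator anyway but does not replace a true copy of a trained model.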

YanisLalou commented 2 months ago

Linked to #163 and fixed by #169