KevinMusgrave / pytorch-adapt

Domain adaptation made easy. Fully featured, modular, and customizable.
https://kevinmusgrave.github.io/pytorch-adapt/
MIT License
360 stars, 15 forks

Extension of the TargetDataset class. #60

Closed: just-eoghan closed this issue 2 years ago

just-eoghan commented 2 years ago

Suggested Feature

A) The addition of a new TargetDataset class for supervised domain adaptation.

or

B) The extension of the TargetDataset class to return labels when passed a supervised flag.

I think option B could be cleaner.

Implementation

A) Create a new class, named something like SupervisedTargetDataset, that returns target_labels.

or

B) Update the init function of the TargetDataset to include a supervised flag.

    def __init__(self, dataset: Dataset, domain: int = 1, supervised: bool = False):
        """
        Arguments:
            dataset: The dataset to wrap
            domain: An integer representing the domain.
            supervised: If True, __getitem__ also returns the target labels.
        """
        super().__init__(dataset, domain)
        self.supervised = supervised

Update the __getitem__ method to behave differently under supervised domain adaptation.

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        if self.supervised:
            img, target_labels = self.dataset[idx]
            return {
                "target_imgs": img,
                "target_domain": self.domain,
                "target_labels": target_labels,
                "target_sample_idx": idx,
            }
        else:
            img, _ = self.dataset[idx]
            return {
                "target_imgs": img,
                "target_domain": self.domain,
                "target_sample_idx": idx,
            }

Reasoning

Supervised domain adaptation needs labels in the target domain, but I think it would still be useful to distinguish between the two domains by using different classes, rather than wrapping the target dataset in SourceDataset to achieve the same functionality.

With this change, validators such as AccuracyValidator could be used on target_val in a supervised domain adaptation setting.


BTW: With these feature suggestions I am happy to do code PRs along with the docs as I previously mentioned!

just-eoghan commented 2 years ago

My bad, I just saw this previous issue: #13

Maybe the approach above would work?

KevinMusgrave commented 2 years ago

I think a flag makes sense. There are 3 scenarios I can think of:

So maybe a supervised flag like you've suggested, and a dataset_has_labels flag?

If dataset_has_labels = True then: img, target_labels = self.dataset[idx]. Otherwise it needs to be img = self.dataset[idx]
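
A minimal sketch of the two-flag idea, written in plain Python so it runs without PyTorch. TwoFlagTargetDataset is a hypothetical name, not pytorch-adapt's actual class, and raising a ValueError for supervised=True with dataset_has_labels=False is an assumption (supervised adaptation has nothing to work with in that case).

```python
from typing import Any, Dict, Sequence


class TwoFlagTargetDataset:
    """Hypothetical sketch of the two-flag design; not pytorch-adapt's actual class."""

    def __init__(
        self,
        dataset: Sequence,
        domain: int = 1,
        supervised: bool = False,
        dataset_has_labels: bool = False,
    ):
        # Assumption: reject the one combination that cannot work.
        if supervised and not dataset_has_labels:
            raise ValueError("supervised=True requires dataset_has_labels=True")
        self.dataset = dataset
        self.domain = domain
        self.supervised = supervised
        self.dataset_has_labels = dataset_has_labels

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        # dataset_has_labels decides how to unpack the wrapped item.
        if self.dataset_has_labels:
            img, target_labels = self.dataset[idx]
        else:
            img = self.dataset[idx]
        item = {
            "target_imgs": img,
            "target_domain": self.domain,
            "target_sample_idx": idx,
        }
        # supervised decides whether the labels are exposed.
        if self.supervised:
            item["target_labels"] = target_labels
        return item


# Strings stand in for image tensors.
labeled = [("img0", 3), ("img1", 7)]
unlabeled = ["img0", "img1"]

print(TwoFlagTargetDataset(labeled, supervised=True, dataset_has_labels=True)[1])
print(TwoFlagTargetDataset(labeled, dataset_has_labels=True)[1])  # labels ignored
print(TwoFlagTargetDataset(unlabeled)[1])
```

The three calls cover the three valid flag combinations: labels used, labels present but ignored, and no labels at all.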

> BTW: With these feature suggestions I am happy to do code PRs along with the docs as I previously mentioned!

Yes that would be great 👍

just-eoghan commented 2 years ago

That looks good to me.

I'll put together a PR on this.

Thanks!

just-eoghan commented 2 years ago

Just working on the PR now.

On this:

> If dataset_has_labels = True then: img, target_labels = self.dataset[idx]. Otherwise it needs to be img = self.dataset[idx]

If the user calls TargetDataset with supervised=True, can we assume that the dataset is capable of returning labels and unpack self.dataset[idx] as (img, label)?

Something like this.

# import c_f into the TargetDataset module
from ..utils import common_functions as c_f

if self.supervised:
    try:
        img, labels = self.dataset[idx]
    except ValueError as e:
        # "not enough values to unpack": the dataset returns images only
        c_f.add_error_message(
            e, "\nThe dataset does not return labels, so supervised domain adaptation is not possible."
        )
        raise

To me, dataset_has_labels seems to nearly duplicate supervised.

We could raise an error if the user calls TargetDataset with supervised=True on a dataset that cannot return labels.

What are your thoughts on this?

KevinMusgrave commented 2 years ago

I'd like it to still support the "academic unsupervised" situation, where supervised is False, but the wrapped dataset returns labels. Maybe you could move the try/except outside the if statement:

    try:
        img, labels = self.dataset[idx]
    except ValueError:
        img = self.dataset[idx]

    # create dict, and add labels to it only if self.supervised is True

KevinMusgrave commented 2 years ago

I don't really like try/except, though. For example, if the batch size is 2 and self.dataset returns only imgs, the unpacking above won't raise a ValueError: the two images just get unpacked into img and labels.
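
To illustrate the concern: tuple unpacking succeeds on any iterable of length 2, so an image-only item whose first dimension has size 2 is silently misread as (img, labels). The sketch below uses a hypothetical FakeImage class as a stand-in for a tensor, so it runs without PyTorch; load_item is likewise a hypothetical helper wrapping the try/except pattern above.

```python
class FakeImage:
    """Hypothetical stand-in for a tensor whose first dimension has size 2."""

    def __init__(self, rows):
        self.rows = rows

    def __iter__(self):
        # Like a tensor, iterating yields slices along the first dimension.
        return iter(self.rows)


def load_item(dataset, idx):
    # The try/except pattern, adapted to return labels (None when absent).
    try:
        img, labels = dataset[idx]
    except ValueError:
        img, labels = dataset[idx], None
    return img, labels


# An image-only dataset: there are no labels anywhere.
images_only = [FakeImage([[0.1, 0.2], [0.3, 0.4]])]
img, labels = load_item(images_only, 0)
print(img, labels)  # the second row is wrongly treated as the "label"
```

No exception is raised, so the except branch never runs and the bug goes unnoticed.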

just-eoghan commented 2 years ago

Good point, I didn't consider that.

Is it best to go back to the two-flag approach?

KevinMusgrave commented 2 years ago

I agree that having to set both supervised=True and dataset_has_labels=True could be annoying. So instead of try/except, maybe check the type:

img = self.dataset[idx]
if isinstance(img, (list, tuple)):
    img, labels = img
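
Combining the type check with a single supervised flag might look like the sketch below. It is a plain-Python stand-in, not pytorch-adapt's actual implementation: TypeCheckTargetDataset is a hypothetical name, and raising a ValueError when supervised=True but the dataset returns no labels is an assumption that follows the earlier suggestion to throw an error.

```python
from typing import Any, Dict, Sequence


class TypeCheckTargetDataset:
    """Hypothetical sketch of the single-flag design; not pytorch-adapt's actual class."""

    def __init__(self, dataset: Sequence, domain: int = 1, supervised: bool = False):
        self.dataset = dataset
        self.domain = domain
        self.supervised = supervised

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        img = self.dataset[idx]
        labels = None
        # Only a list/tuple is treated as (img, labels); a tensor-like image
        # is never unpacked, whatever the size of its first dimension.
        if isinstance(img, (list, tuple)):
            img, labels = img
        item = {
            "target_imgs": img,
            "target_domain": self.domain,
            "target_sample_idx": idx,
        }
        if self.supervised:
            # Assumption: fail loudly instead of returning missing labels.
            if labels is None:
                raise ValueError("supervised=True, but the dataset does not return labels")
            item["target_labels"] = labels
        return item


# Strings stand in for image tensors.
labeled = [("img0", 3), ("img1", 7)]
unlabeled = ["img0", "img1"]

print(TypeCheckTargetDataset(labeled, supervised=True)[0])
print(TypeCheckTargetDataset(labeled)[0])  # "academic unsupervised": labels dropped
print(TypeCheckTargetDataset(unlabeled)[0])
```

With this design the user only sets supervised; whether the wrapped dataset returns labels is inferred from the item's type.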

just-eoghan commented 2 years ago

PR created #61

KevinMusgrave commented 2 years ago

Thanks!