KevinMusgrave / pytorch-adapt

Domain adaptation made easy. Fully featured, modular, and customizable.
https://kevinmusgrave.github.io/pytorch-adapt/
MIT License

Allow TargetDataset to return data with labels #61

Closed: just-eoghan closed this 2 years ago

just-eoghan commented 2 years ago

Use type checking and a supervised flag to allow TargetDataset to return labels when the target domain is supervised.
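A minimal sketch of the idea, assuming the wrapped dataset yields either a bare sample or a (sample, label) tuple. The class names SketchDomainDataset and SketchTargetDataset and the output keys (target_imgs, target_domain, target_labels) are hypothetical, loosely modeled on pytorch-adapt's naming; this is not the library's actual implementation.

```python
from typing import Any, Dict, Sequence


class SketchDomainDataset:
    """Hypothetical base class: wraps a dataset and tags it with a domain id."""

    def __init__(self, dataset: Sequence[Any], domain: int):
        self.dataset = dataset
        self.domain = domain

    def __len__(self) -> int:
        return len(self.dataset)


class SketchTargetDataset(SketchDomainDataset):
    """Hypothetical target wrapper; only this subclass knows about `supervised`."""

    def __init__(self, dataset: Sequence[Any], domain: int = 1, supervised: bool = False):
        super().__init__(dataset, domain)
        self.supervised = supervised

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        item = self.dataset[idx]
        # Type check: a (data, label) tuple/list counts as labelled, anything else as unlabelled.
        if isinstance(item, (tuple, list)):
            img, label = item
        else:
            img, label = item, None
        output = {"target_imgs": img, "target_domain": self.domain}
        # Labels are only exposed when the caller opts in with supervised=True.
        if self.supervised and label is not None:
            output["target_labels"] = label
        return output


pairs = [("img0", 0), ("img1", 1)]
print(SketchTargetDataset(pairs)[0])
# {'target_imgs': 'img0', 'target_domain': 1}
print(SketchTargetDataset(pairs, supervised=True)[1])
# {'target_imgs': 'img1', 'target_domain': 1, 'target_labels': 1}
```

Keeping supervised on the target subclass rather than the base class matches the refactor agreed later in the thread. Note that this sketch silently omits labels when supervised=True but the data is unlabelled, which is exactly the case raised further down.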

KevinMusgrave commented 2 years ago

@deepseek-eoghan Can you create a new test file at tests/datasets/test_target_dataset.py?

Something like:

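import unittest
import torch
from torch.utils.data import TensorDataset
from pytorch_adapt.datasets import TargetDataset

class TestTargetDataset(unittest.TestCase):
    def test_supervised_flag(self):
        # toy (image, label) pairs; "target_labels" is assumed to be the key for returned labels
        data = TensorDataset(torch.randn(4, 3, 8, 8), torch.randint(0, 2, (4,)))
        self.assertNotIn("target_labels", TargetDataset(data)[0])
        self.assertIn("target_labels", TargetDataset(data, supervised=True)[0])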

Then run it like python -m unittest tests/datasets/test_target_dataset.py

just-eoghan commented 2 years ago

Sure will do!

just-eoghan commented 2 years ago

I added test_target_dataset.py

I'm not sure I've gone about it in the best way, so I'll refactor based on your feedback if need be.

106/113 tests are currently passing, so I may have broken some other tests if they were all passing previously.


Should we handle the case where a user creates a target dataset with unlabelled data and passes the supervised flag?

We could print an error message to the console telling the user that the dataset has no labels and will therefore be treated as unsupervised. That would alert a user who intended to pass labels and treat the target domain as supervised that their input dataset doesn't actually contain any.

KevinMusgrave commented 2 years ago

> 106/113 tests are currently passing, so I may have broken some other tests if they were all passing previously.

Hmm, I'll take a look later. I might just need to change some == to np.isclose for it to pass across machines.
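Exact == on floating-point results can fail when the same computation rounds slightly differently across machines (different hardware, BLAS builds, or library versions), while np.isclose compares within a relative and absolute tolerance. A small illustration, not taken from the project's test suite:

```python
import numpy as np

# The "same" value computed two ways can differ in the last bits.
a = 0.1 + 0.2
b = 0.3

print(a == b)            # False: exact comparison sees the rounding difference
print(np.isclose(a, b))  # True: equal within rtol=1e-05, atol=1e-08 (the defaults)

# np.isclose works element-wise on arrays; np.allclose reduces to a single bool.
x = np.array([1.0, 2.0, 3.0])
y = x + 1e-9
print(np.allclose(x, y))  # True
```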

> Should we handle the case where a user creates a target dataset with unlabelled data and passes the supervised flag?
>
> We could print an error message to the console telling the user that the dataset has no labels and will therefore be treated as unsupervised. That would alert a user who intended to pass labels and treat the target domain as supervised that their input dataset doesn't actually contain any.

Good idea, but I think it should be an exception rather than a warning.
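A console warning is easy to miss in a long training log, whereas an exception stops the run the moment a bad sample is read. A minimal sketch of such a check, assuming (as in the type-checking approach this PR describes) that a tuple or list item means "labelled"; split_target_item is a hypothetical helper name and ValueError is an assumed choice of exception type.

```python
from typing import Any, Optional, Tuple


def split_target_item(item: Any, supervised: bool) -> Tuple[Any, Optional[Any]]:
    """Hypothetical helper: return (data, label), failing loudly on a bad configuration."""
    has_labels = isinstance(item, (tuple, list))
    if supervised and not has_labels:
        # Fail fast instead of silently falling back to unsupervised mode.
        raise ValueError(
            "supervised=True, but the wrapped dataset does not return labels"
        )
    if has_labels:
        data, label = item
        # Drop the label unless the caller asked for supervised data.
        return data, (label if supervised else None)
    return item, None
```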

Regarding the code, does DomainDataset need the supervised flag?

just-eoghan commented 2 years ago

> Good idea, but I think it should be an exception rather than a warning.

Perfect, I'll add that exception to handle that case now.

> Regarding the code, does DomainDataset need the supervised flag?

You're right, it isn't necessary! I refactored the code to remove it.

just-eoghan commented 2 years ago

Think that's all for now.

Let me know if there is any further refactoring required and I'll jump back on it!

KevinMusgrave commented 2 years ago

Looks good, thanks!

KevinMusgrave commented 2 years ago

By the way, the tests did pass on my system.

just-eoghan commented 2 years ago

I can debug the tests on my side and see why they were failing for me. I'll open a separate issue for that.

Also, I was just looking at the get_mnist_mnistm functionality, specifically its use of get_datasets.

Would it be good to update this to use our new TargetDataset?

Currently get_datasets takes a parameter, return_target_with_labels, which is used in the following block:

    if target_domains:
        output["target_train"] = TargetDataset(getter(target_domains, True, False))
        output["target_val"] = TargetDataset(getter(target_domains, False, False))
        if return_target_with_labels:
            output["target_train_with_labels"] = SourceDataset(
                getter(target_domains, True, False), domain=1
            )
            output["target_val_with_labels"] = SourceDataset(
                getter(target_domains, False, False), domain=1
            )

Could these _with_labels datasets now be changed to TargetDataset?
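A self-contained sketch of what the suggested change could look like. TargetDatasetStub stands in for the real TargetDataset, and build_target_outputs and fake_getter are hypothetical names; the getter's two boolean arguments are copied unchanged from the snippet above rather than interpreted. It assumes that TargetDataset's default domain is 1 (matching the domain=1 the SourceDataset calls set explicitly) and that supervised=True makes it return labels.

```python
from typing import Any, Callable, Dict, List, Sequence


class TargetDatasetStub:
    """Hypothetical stand-in for TargetDataset, reduced to the fields this sketch needs."""

    def __init__(self, dataset: Sequence[Any], domain: int = 1, supervised: bool = False):
        self.dataset = dataset
        self.domain = domain
        self.supervised = supervised


def build_target_outputs(
    getter: Callable[[List[str], bool, bool], Sequence[Any]],
    target_domains: List[str],
    return_target_with_labels: bool,
) -> Dict[str, TargetDatasetStub]:
    output: Dict[str, TargetDatasetStub] = {}
    if target_domains:
        output["target_train"] = TargetDatasetStub(getter(target_domains, True, False))
        output["target_val"] = TargetDatasetStub(getter(target_domains, False, False))
        if return_target_with_labels:
            # Previously SourceDataset(..., domain=1); now the same target data
            # wrapped as a target dataset that also returns labels.
            output["target_train_with_labels"] = TargetDatasetStub(
                getter(target_domains, True, False), supervised=True
            )
            output["target_val_with_labels"] = TargetDatasetStub(
                getter(target_domains, False, False), supervised=True
            )
    return output


def fake_getter(domains: List[str], train: bool, extra: bool) -> List[Any]:
    # Hypothetical getter: one labelled sample per split.
    return [("train_img" if train else "val_img", 0)]


out = build_target_outputs(fake_getter, ["mnistm"], return_target_with_labels=True)
print(sorted(out))
# ['target_train', 'target_train_with_labels', 'target_val', 'target_val_with_labels']
print(out["target_train_with_labels"].supervised)  # True
```

The design benefit is that every target split now carries the target domain id through the same class, and only the supervised flag distinguishes the labelled variants.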

KevinMusgrave commented 2 years ago

> Could these _with_labels datasets now be changed to TargetDataset?

Yeah, that makes sense 👍

just-eoghan commented 2 years ago

Great, thanks!

I'll open a new PR