@deepseek-eoghan Can you create a new test file at `tests/datasets/test_target_dataset.py`?
Something like:

```python
import unittest

from pytorch_adapt.datasets import TargetDataset


class TestTargetDataset(unittest.TestCase):
    def test_supervised_flag(self):
        # some basic sanity checks, like checking the keys of the output
        # when supervised is True or False
        pass
```
Then run it like `python -m unittest tests/datasets/test_target_dataset.py`.
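For example, the sanity checks might look something like this (just a sketch: the `(img, label)` wrapped dataset and the exact output key `"target_labels"` are assumptions, not guaranteed to match the final API):

```python
import unittest

import torch
from pytorch_adapt.datasets import TargetDataset


class TestTargetDataset(unittest.TestCase):
    def test_supervised_flag(self):
        # Hypothetical wrapped dataset of (img, label) pairs
        base = [(torch.randn(3, 32, 32), i % 10) for i in range(8)]

        # Assumed behavior: no label key when supervised is False (the default)
        unsupervised_item = TargetDataset(base)[0]
        self.assertNotIn("target_labels", unsupervised_item)

        # Assumed behavior: labels appear in the output when supervised=True
        supervised_item = TargetDataset(base, supervised=True)[0]
        self.assertIn("target_labels", supervised_item)


if __name__ == "__main__":
    unittest.main()
```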
Sure, will do!
I added test_target_dataset.py
I'm not sure I have gone about it in the best way, so I'll refactor based on your feedback if need be.
106/113 tests are currently passing, so I may have broken some other tests if they were all passing previously.
Should we handle the case where a user creates a target dataset with unlabelled data and passes the supervised flag?
We could write an error message to the console telling the user the dataset has no labels, so it will be treated as unsupervised. This would flag to the user that their input dataset doesn't contain any labels, in case they had intended to pass labels and treat the target domain as supervised.
> 106/113 tests are currently passing, so I may have broken some other tests if they were all passing previously.
Hmm, I'll take a look later. I might just need to change some `==` to `np.isclose` for it to pass across machines.
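For instance (illustrative only, not the actual failing assertion):

```python
import numpy as np

computed = 0.1 + 0.2

# Exact float equality is brittle: this prints False under IEEE-754 arithmetic
print(computed == 0.3)

# Comparing within a tolerance passes reliably across machines
print(np.isclose(computed, 0.3))  # True
```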
> Should we handle the case where a user creates a target dataset with unlabelled data and passes the supervised flag?
> We could write an error message to the console telling the user the dataset has no labels, so it will be treated as unsupervised.
Good idea, but I think make it an exception rather than a warning.
Regarding the code, does `DomainDataset` need the supervised flag?
> Good idea, but I think make it an exception rather than a warning.
Perfect, I'll add an exception to handle that case now.
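Roughly along these lines (a sketch; the constructor signature and the shape of the label check are assumptions, not the final implementation):

```python
from pytorch_adapt.datasets import DomainDataset


class TargetDataset(DomainDataset):
    def __init__(self, dataset, domain=1, supervised=False):
        super().__init__(dataset, domain)
        self.supervised = supervised
        if self.supervised:
            # Peek at one sample: if the wrapped dataset doesn't yield
            # (img, label) pairs, supervised mode can't return labels.
            sample = dataset[0]
            if not (isinstance(sample, (list, tuple)) and len(sample) == 2):
                raise ValueError(
                    "supervised=True but the dataset has no labels; "
                    "pass (img, label) pairs or use supervised=False"
                )
```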
> Regarding the code, does `DomainDataset` need the supervised flag?
You're right, it isn't necessary! I refactored the code to remove it.
Think that's all for now.
Let me know if there is any further refactoring required and I'll jump back on it!
Looks good, thanks!
By the way, the tests did pass on my system.
I can debug the tests on my side and see why they were failing for me. I'll open a separate issue for that.
Also, I was just looking at the `get_mnist_mnistm` functionality, specifically its use of `get_datasets`. Would it be good to update this to use our new `TargetDataset`?
Currently, `get_datasets` takes a parameter `return_target_with_labels` which, if true, executes the following:
```python
if target_domains:
    output["target_train"] = TargetDataset(getter(target_domains, True, False))
    output["target_val"] = TargetDataset(getter(target_domains, False, False))
    if return_target_with_labels:
        output["target_train_with_labels"] = SourceDataset(
            getter(target_domains, True, False), domain=1
        )
        output["target_val_with_labels"] = SourceDataset(
            getter(target_domains, False, False), domain=1
        )
```
Could these `_with_labels` datasets now be changed to `TargetDataset`?
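i.e. something like this (a sketch, assuming `TargetDataset` defaults to the target domain and accepts the new `supervised` flag):

```python
if return_target_with_labels:
    output["target_train_with_labels"] = TargetDataset(
        getter(target_domains, True, False), supervised=True
    )
    output["target_val_with_labels"] = TargetDataset(
        getter(target_domains, False, False), supervised=True
    )
```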
> Could these `_with_labels` datasets now be changed to `TargetDataset`?
Yeah that makes sense :+1:
Great, thanks! I'll open a new PR.
Use type checking and supervised flag to allow target dataset to return labels under supervised conditions.