VisionLearningGroup / taskcv-2017-public


Learning By Ignoring - Problem running on different dataset #14

Closed: bricksaver closed this issue 3 years ago

bricksaver commented 3 years ago

I have a question. I am currently trying to run "Learning-By-Ignoring" on the VisDA-2017 dataset. It is imported the same way as the Office31 and OfficeHome datasets, using this line from the Learning By Ignoring code:

    from dalib.vision.datasets import Office31, OfficeHome, VisDA2017

Currently, the line of code at line 195 of my script raises an error. [Screenshots of that line and of the error were attached as images and are not preserved here.] I am really not sure what is causing this. I checked, and my VisDA data has the same structure as my Office31 data (which I was able to run). The lines of code in the screenshot are identical to those used to run Office31, and I saw that visda is included as a possible option for --dataset. I was wondering if you had any input on this, Professor. Sorry, I know this is a debugging problem, but I'm really stuck. I attached my code for reference.

I run the code with the following command:

    python main_custom.py --save_dir=T_V_ours2 --gpu=0 --source_domain=T --target_domain=V --dataset=visda2017 --ours2 --lam=5e-4 --batch_size=32

and get the following output in my log.txt file before the error appears and stops the run:

    POINT 6.2
    POINT 7
    POINT 8
    POINT 9
    data_root: temp/data/visda2017
    POINT 10
    POINT 10.1
    POINT 10.2
    POINT 11
    train source_task: T task: T_train

Please let me know if any additional information would be helpful. My code is below for reference:

bricksaver commented 3 years ago

Oh, actually I solved it. In visda2017.py, the `__init__` definition was missing the following lines:

    if task == 'T_train':
        domain_idx = 0
    elif task == 'V_train':
        domain_idx = 1
    else:
        domain_idx = -1

and, in the call to the parent class's `__init__`, the keyword argument:

    domain_idx=domain_idx,
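For reference, the task-to-domain mapping in the if/elif chain can also be written as a dictionary lookup, which keeps the fallback of -1 for any other task name. This is a minimal sketch; the helper `task_to_domain_idx` and the constant `TASK_TO_DOMAIN_IDX` are hypothetical names, not part of dalib or the Learning-By-Ignoring code.

```python
# Hypothetical helper (not part of dalib or Learning-By-Ignoring): the same
# task -> domain index mapping as the if/elif chain, written as a dict lookup.
TASK_TO_DOMAIN_IDX = {"T_train": 0, "V_train": 1}


def task_to_domain_idx(task: str) -> int:
    """Return the domain index for a task name; unknown tasks map to -1."""
    return TASK_TO_DOMAIN_IDX.get(task, -1)


if __name__ == "__main__":
    for task in ("T_train", "V_train", "some_other_task"):
        print(task, task_to_domain_idx(task))
```

With this form, supporting another domain means adding one dictionary entry instead of another elif branch.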

Below is the full `__init__` definition:

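    def __init__(self,
                 root: str,
                 task: str,
                 download: Optional[bool] = False,
                 **kwargs):
        # Method of the VisDA2017 class; os, Optional, download_data and
        # check_exits are imported at the top of visda2017.py.
        assert task in self.image_list
        data_list_file = os.path.join(root, self.image_list[task])

        # Added: map the task name to a domain index.
        if task == 'T_train':
            domain_idx = 0
        elif task == 'V_train':
            domain_idx = 1
        else:
            domain_idx = -1

        if download:
            # Each download_list entry is a tuple unpacked into download_data.
            list(map(lambda args: download_data(root, *args),
                     self.download_list))
        else:
            # map() passes each tuple as a single argument, so index it here
            # instead of using a two-parameter lambda (which raises TypeError).
            list(map(lambda args: check_exits(root, args[0]),
                     self.download_list))

        super(VisDA2017, self).__init__(root,
                                        VisDA2017.CLASSES,
                                        data_list_file=data_list_file,
                                        domain_idx=domain_idx,  # added
                                        **kwargs)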
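A related pitfall in the download=False branch: a two-parameter lambda such as `lambda file_name, _: check_exits(root, file_name)` cannot be mapped over download_list. With a single iterable, map() calls the function with one argument per element (the whole tuple), so the lambda fails with a TypeError as soon as the list is non-empty. The check has to accept the tuple and index its first element. Below is a minimal, self-contained sketch; `DOWNLOAD_LIST` and `fake_check_exits` are placeholders, not dalib code.

```python
# Self-contained demonstration of the map()/lambda pitfall.
# Placeholder entries shaped like (name, archive, url) tuples.
DOWNLOAD_LIST = [
    ("image_list", "image_list.zip", "placeholder-url"),
    ("train", "train.tar", "placeholder-url"),
]


def fake_check_exits(root: str, file_name: str) -> str:
    # Hypothetical stand-in for check_exits: returns the path it would check.
    return f"{root}/{file_name}"


def check_two_param_lambda(root: str) -> list:
    # map() calls the lambda with ONE argument (the whole tuple) -> TypeError.
    return list(map(lambda file_name, _: fake_check_exits(root, file_name),
                    DOWNLOAD_LIST))


def check_unpacking_lambda(root: str) -> list:
    # Accept the tuple as a single argument and index its first element.
    return list(map(lambda args: fake_check_exits(root, args[0]),
                    DOWNLOAD_LIST))


if __name__ == "__main__":
    try:
        check_two_param_lambda("data")
    except TypeError as err:
        print("TypeError:", err)
    print(check_unpacking_lambda("data"))  # ['data/image_list', 'data/train']
```

An equivalent alternative is a list comprehension that unpacks explicitly, e.g. `[check_exits(root, name) for name, *_ in self.download_list]`.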