Thunderbeee / ZSCL

Preventing Zero-Shot Transfer Degradation in Continual Learning of Vision-Language Models

NotImplementedError while loading Distillation dataset #4

Open YCAyca opened 3 months ago

YCAyca commented 3 months ago

Hello, I am trying to use the ZSCL method with the suggested datasets (ImageNet and Conceptual Captions). I prepared the distillation dataset using the https://github.com/ml-jku/cloob repository, so I already have the "Validation_GCC-1.1.0-Validation_output.csv" file. However, training fails with the following error:

0%| | 0/1301 [00:00<?, ?it/s]
Error executing job with overrides: []
Traceback (most recent call last):
  File "main.py", line 51, in continual_clip
    model.adaptation(task_id, cfg, train_dataset, train_classes_names)
  File "/workspace/ZSCL/cil/continual_clip/models.py", line 48, in adaptation
    self.train(task_id, cfg, train_dataset, train_classes_names)
  File "/workspace/ZSCL/cil/continual_clip/models.py", line 230, in train
    ref_images, ref_labels = next(ref_iter)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataset.py", line 33, in __getitem__
    raise NotImplementedError
NotImplementedError

I know the problem is related to the distillation dataset (Conceptual Captions), but I can't fix it. Can anybody help?

Thunderbeee commented 3 months ago

Thanks for reaching out to us! Below are the instructions provided by Dr. Zangwei Zheng:

First, download Validation_GCC-1.1.0-Validation.tsv from the Conceptual Captions dataset here. Then, use gather_cc.py to download the images. After running the script, you should get both a folder of images and a Validation_GCC-1.1.0-Validation_output.csv file. The Conceptual Captions processing is the same as in this repo.

python gather_cc.py Validation_GCC-1.1.0-Validation.tsv
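
Before using the result for training, it may help to confirm that the generated CSV has the shape the loader expects. The sketch below is only an illustration: it assumes the file is tab-separated with `filepath` and `title` columns (the separator and keys the repo's `CsvDataset` is called with) and that image paths are relative to the directory containing the CSV. The helper name `check_cc_csv` is hypothetical and not part of the repository.

```python
import csv
import os


def check_cc_csv(csv_path, img_key="filepath", caption_key="title", sep="\t"):
    """Return (row_count, missing_images) for a gather_cc.py-style output file.

    Assumptions: the file has a header row containing img_key and
    caption_key, and image paths are relative to the CSV's directory.
    """
    base_dir = os.path.dirname(csv_path)
    rows, missing = 0, []
    with open(csv_path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f, delimiter=sep)
        fields = reader.fieldnames or []
        # Fail early if the expected columns are absent.
        for key in (img_key, caption_key):
            if key not in fields:
                raise KeyError(f"column {key!r} not found; got {fields}")
        for row in reader:
            rows += 1
            # Collect entries whose image file is not on disk.
            if not os.path.exists(os.path.join(base_dir, row[img_key])):
                missing.append(row[img_key])
    return rows, missing
```

Calling `check_cc_csv` on the path to Validation_GCC-1.1.0-Validation_output.csv returns the number of rows and a list of image paths that could not be found, which separates data-preparation problems from problems in the dataset class itself.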

YCAyca commented 3 months ago

Thanks, but the question is not how to create Validation_GCC-1.1.0-Validation_output.csv; it is a different problem. I have already done these steps and I have the Validation_GCC-1.1.0-Validation_output.csv file, but the code gives me the error above. Something must be wrong with the conceptual_captions(Dataset) class, but I don't understand what. The error says that no __getitem__ function is implemented for the conceptual_captions class, but it is implemented. Below is cc.py; I have only changed the path to the .csv file and added two print() calls. Apparently initialization is fine, but __getitem__ is never called.

class CsvDataset(Dataset):
    def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
        df = pd.read_csv(input_filename, sep=sep)
        print("INIT")
        self.location = os.path.dirname(input_filename)
        self.images = df[img_key].tolist()
        self.captions = df[caption_key].tolist()
        self.transforms = transforms

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        print("GETITEM")
        image_path = os.path.join(self.location, str(self.images[idx]))
        images = self.transforms(Image.open(image_path))
        texts = clip.tokenize([str(self.captions[idx])])[0]
        return images, texts

class conceptual_captions(Dataset):
    def __init__(
        self, transforms, location, batch_size, *args, num_workers=16, **kwargs
    ):
        file_name = "/workspace/ZSCL/cil/data/Validation_GCC-1.1.0-Validation_output.csv"
        file_path = file_name
        self.template = lambda c: f"a photo of a {c}."
        self.train_dataset = CsvDataset(
            input_filename=file_path,
            transforms=transforms,
            img_key="filepath",
            caption_key="title",
        )
        # breakpoint()
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
        )

YCAyca commented 3 months ago

Okay, I found the problem, and the solution seems to work now. In cc.py, conceptual_captions is defined as a class inheriting from Dataset, and it initializes an instance of CsvDataset via self.train_dataset = CsvDataset(...), but it never uses it.

That is why the print("INIT") line ran but the print("GETITEM") line did not: when models.py calls ref_images, ref_labels = next(ref_iter) (line 230 in the traceback), it invokes __getitem__ on the conceptual_captions class, since ref_iter was created from an instance of this class. That class does not define __getitem__, so the call falls through to the base Dataset implementation, which raises NotImplementedError.

I will open a pull request with the fixed code soon!
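
The behaviour described above can be reproduced without the repository. The sketch below is a minimal illustration, not the project's actual fix: `Dataset` is a stand-in for the `torch.utils.data.Dataset` base class seen in the traceback (whose `__getitem__` raises `NotImplementedError`), and `InnerDataset`, `BrokenWrapper`, `FixedWrapper`, and `first_item` are hypothetical names. It shows that `iter()` on an object without `__iter__` falls back to `__getitem__`, and one possible remedy: forwarding `__len__` and `__getitem__` to the inner dataset.

```python
class Dataset:
    """Stand-in for torch.utils.data.Dataset as seen in the traceback."""

    def __getitem__(self, index):
        raise NotImplementedError


class InnerDataset(Dataset):
    """Plays the role of CsvDataset: it really implements item access."""

    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


class BrokenWrapper(Dataset):
    """Builds an inner dataset but never exposes it (the reported bug)."""

    def __init__(self, items):
        self.train_dataset = InnerDataset(items)


class FixedWrapper(Dataset):
    """One possible fix: forward length and item access to the inner dataset."""

    def __init__(self, items):
        self.train_dataset = InnerDataset(items)

    def __len__(self):
        return len(self.train_dataset)

    def __getitem__(self, idx):
        return self.train_dataset[idx]


def first_item(dataset):
    # With no __iter__ defined, iter() falls back to calling
    # __getitem__(0), __getitem__(1), ... on the object.
    return next(iter(dataset))
```

Here `first_item(BrokenWrapper([...]))` raises `NotImplementedError` from the base class, while `first_item(FixedWrapper([...]))` returns the first element. An alternative that should avoid the same error is to iterate over the wrapper's `train_loader` instead of the wrapper itself, which would yield batches rather than single samples.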