gathierry / FastFlow

Apache License 2.0

The code might need data augmentation #9

Closed: cytotoxicity8 closed this issue 2 years ago

cytotoxicity8 commented 2 years ago

As described in Section 6.2, "Training Data Augmentation", the authors applied several data augmentation methods. Although they also report results without augmentation, Table 2 appears to have been produced with it (compare the AUC values with Table 8). So I think dataset.py should be changed as follows (*I didn't apply random rotation):

```python
import os
from glob import glob

import torch
from PIL import Image
from torchvision import transforms


class MVTecDataset(torch.utils.data.Dataset):
    def __init__(self, root, category, input_size, is_train=True):
        self.image_transform = transforms.Compose(
            [
                transforms.Resize(input_size),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        if is_train:
            # Edited: random flips, applied to training images only.
            self.augment_transform = transforms.Compose(
                [
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomVerticalFlip(p=0.3),
                ]
            )
            self.image_files = glob(
                os.path.join(root, category, "train", "good", "*.png")
            )
        else:
            self.image_files = glob(os.path.join(root, category, "test", "*", "*.png"))
            self.target_transform = transforms.Compose(
                [
                    transforms.Resize(input_size),
                    transforms.ToTensor(),
                ]
            )
        self.is_train = is_train

    def __getitem__(self, index):
        image_file = self.image_files[index]
        # Some MVTec AD categories are grayscale; convert so Normalize sees 3 channels.
        image = Image.open(image_file).convert("RGB")
        image = self.image_transform(image)
        if self.is_train:
            # The flips run on the normalized tensor (tensor input is supported
            # by recent torchvision versions).
            image = self.augment_transform(image)
            return image
        else:
            if os.path.dirname(image_file).endswith("good"):
                # Normal test images get an all-zero mask.
                target = torch.zeros([1, image.shape[-2], image.shape[-1]])
            else:
                # Defect masks live under ground_truth/ with a "_mask" suffix.
                target = Image.open(
                    image_file.replace("/test/", "/ground_truth/").replace(
                        ".png", "_mask.png"
                    )
                )
                target = self.target_transform(target)
            return image, target

    def __len__(self):
        return len(self.image_files)
```

Asura-Ace commented 2 years ago

I've tried horizontal flip, vertical flip, and random rotation, but performance drops in several categories.

cytotoxicity8 commented 2 years ago

Yes, that's true, and the authors mention it as well. I think the reason has to do with the symmetry properties of many categories. However, the authors didn't say which categories they applied data augmentation to, so I think data augmentation should be an optional setting in the code.
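
One possible shape for making augmentation optional is a per-category policy table plus a flag that turns augmentation off entirely. This is only a sketch: the paper does not say which categories were augmented, so `AUGMENT_POLICY`, `augment_ops`, and `build_augment_transform` are hypothetical names, and the table entries are illustrative assumptions rather than the authors' settings. The flip probabilities reuse the values from the edited dataset.py above.

```python
# Hypothetical per-category augmentation policy. The entries are illustrative
# assumptions only; the paper does not specify them.
AUGMENT_POLICY = {
    "bottle": ("hflip", "vflip"),
    "carpet": ("hflip", "vflip"),
    # MVTec AD has a "flip" defect type for metal_nut, so flipping its
    # normal samples could plausibly hurt; no flips here.
    "metal_nut": (),
}


def augment_ops(category, enabled=True):
    """Return the names of the flips to apply for `category` (empty if disabled)."""
    if not enabled:
        return ()
    # Categories missing from the table get no augmentation by default.
    return AUGMENT_POLICY.get(category, ())


def build_augment_transform(category, enabled=True):
    """Build a torchvision transform from the policy, or None if nothing applies."""
    ops = augment_ops(category, enabled)
    if not ops:
        return None
    # Imported lazily so the policy lookup above works without torchvision.
    from torchvision import transforms

    factories = {
        "hflip": lambda: transforms.RandomHorizontalFlip(p=0.5),
        "vflip": lambda: transforms.RandomVerticalFlip(p=0.3),
    }
    return transforms.Compose([factories[name]() for name in ops])
```

In `MVTecDataset.__init__`, the training branch could then set `self.augment_transform = build_augment_transform(category, enabled=augment)` for a hypothetical `augment` constructor argument, and `__getitem__` would apply the transform only when it is not `None`.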

gathierry commented 2 years ago

Nice catch, @cytotoxicity8. I didn't perform data augmentation because the paper doesn't give enough detail on which augmentations were applied to which categories. However, the paper shows that performance without augmentation does not drop significantly, so I just omitted it.