Simon4Yan / Meta-set

Automatic model evaluation (AutoEval), CVPR'21 & TPAMI'22
MIT License
35 stars · 5 forks · source link

Test an MNIST-trained model on SVHN and USPS #1

Open · Simon4Yan opened this issue 2 years ago

Simon4Yan commented 2 years ago

Pretrained MNIST model

USPS

DATASET

class USPS(data.Dataset):
    """USPS handwritten-digit dataset read from LIBSVM-format ``.bz2`` files.

    Each non-empty line of the file is ``<label> <index>:<value> ...`` with
    1-based pixel indices into a flattened 16x16 image and pixel values
    expected on a [-1, 1] scale. Pixels are rescaled to ``uint8`` in
    [0, 255]; the file's label ``k`` becomes class index ``k - 1``.

    Args:
        root (str): Directory containing ``usps.bz2`` / ``usps.t.bz2``.
        train (bool): Load ``usps.bz2`` if True, otherwise ``usps.t.bz2``.
        transform (callable, optional): Applied to the RGB PIL image.
        target_transform (callable, optional): Applied to the integer label.

    Attributes:
        data (numpy.ndarray): ``uint8`` array of shape ``(N, 16, 16)``.
        targets (list[int]): Class index for each sample.

    Raises:
        ValueError: If a pixel index in the file is outside ``1..256``.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None):
        super(USPS, self).__init__()
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        filename = 'usps.bz2' if train else 'usps.t.bz2'
        full_path = os.path.join(self.root, filename)

        import bz2
        with bz2.open(full_path, 'rt', encoding='utf-8') as fp:
            # Skip blank lines (e.g. a trailing newline at end of file),
            # which would otherwise produce an empty row with no label.
            rows = [fields for fields in (line.split() for line in fp) if fields]

        side = 16
        num_pixels = side * side
        # LIBSVM is a sparse format: zero-valued features may be omitted, so
        # each value is placed by its explicit index rather than by position.
        # Omitted pixels keep the value 0.0 on the [-1, 1] scale.
        imgs = np.zeros((len(rows), num_pixels), dtype=np.float32)
        for row_idx, fields in enumerate(rows):
            for token in fields[1:]:
                index, _, value = token.partition(':')
                pos = int(index)
                if not 1 <= pos <= num_pixels:
                    raise ValueError(
                        '{}: pixel index {} out of range 1..{} on sample {}'.format(
                            full_path, pos, num_pixels, row_idx))
                imgs[row_idx, pos - 1] = float(value)
        imgs = imgs.reshape((-1, side, side))
        # Map [-1, 1] -> [0, 255]; clip first so out-of-range values cannot
        # wrap around when cast to uint8.
        imgs = (np.clip((imgs + 1) / 2, 0.0, 1.0) * 255).astype(np.uint8)
        targets = [int(fields[0]) - 1 for fields in rows]

        self.data = imgs
        self.targets = targets

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image (grayscale replicated to 3 RGB channels)
        img = Image.fromarray(img).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.data)
# Extra DataLoader options that only pay off when running on a GPU.
if use_cuda:
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

# Resize the 16x16 USPS digits to 28x28 and normalize each channel with
# mean 0.5 / std 0.5 (i.e. map ToTensor's [0, 1] range onto [-1, 1]).
usps_transform = transforms.Compose([
    transforms.Resize([28, 28]),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
usps_test_set = USPS('../raw_data', train=False, transform=usps_transform)
test_loader = torch.utils.data.DataLoader(
    usps_test_set,
    batch_size=args.test_batch_size,
    shuffle=False,
    **kwargs
)

SVHN

import torchvision.datasets as dataset

# SVHN test split, resized to 28x28 (presumably to match the input size of
# the MNIST-trained model) and normalized with mean 0.5 / std 0.5.
svhn_transform = transforms.Compose([
    transforms.Resize([28, 28]),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
svhn_test_set = dataset.SVHN(
    '../raw_data',
    split='test',
    transform=svhn_transform,
    download=True,
)
test_loader = torch.utils.data.DataLoader(
    svhn_test_set,
    batch_size=args.test_batch_size,
    shuffle=False,
    **kwargs
)