learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License

Are the examples incomplete, or are even the vision examples actually rather toy-like? Or is it just cifarfs? #302

Closed brando90 closed 2 years ago

brando90 commented 2 years ago

Hi Seba,

Apologies if this reads like a complaint; it isn't. I just wanted to clarify something. I was going through the examples and they seem to lack the standard data transforms that people use when running real experiments for a paper, e.g.

        self.mean = [0.5071, 0.4867, 0.4408]
        self.std = [0.2675, 0.2565, 0.2761]
        self.normalize = transforms.Normalize(mean=self.mean, std=self.std)
        self.pretrain = pretrain

        if transform is None:
            if self.partition == 'train' and self.data_aug:
                self.transform = transforms.Compose([
                    lambda x: Image.fromarray(x),
                    transforms.RandomCrop(32, padding=4),
                    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                    transforms.RandomHorizontalFlip(),
                    lambda x: np.asarray(x),
                    transforms.ToTensor(),
                    self.normalize
                ])
            else:
                self.transform = transforms.Compose([
                    lambda x: Image.fromarray(x),
                    transforms.ToTensor(),
                    self.normalize
                ])
        else:
            self.transform = transform

from

https://github.com/WangYueFt/rfs/blob/f8c837ba93c62dd0ac68a2f4019c619aa86b8421/dataset/cifar.py#L26

Is this correct? The standard data transforms that people usually apply do not appear in the examples, so there are no "serious" examples, right?

Thanks for your library! I really love being able to run an experiment in 1 day rather than 15+, and I like how nicely organized it is. Kudos!

brando90 commented 2 years ago

also it seems some of the interfaces assume train, val and test set have the same transforms?
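
One way to give each split its own transform, independent of any shared interface, is to keep the underlying data raw and wrap it once per split. The following is a minimal sketch under that assumption; `TransformedView` and the toy float "images" are hypothetical names for illustration and are not part of learn2learn.

```python
from typing import Callable, Sequence, Tuple


class TransformedView:
    """Wraps raw (sample, label) pairs and applies one split-specific transform.

    Hypothetical helper for illustration; not a learn2learn class.
    """

    def __init__(self, base: Sequence[Tuple[float, int]],
                 transform: Callable[[float], float]):
        self.base = base
        self.transform = transform

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, index: int) -> Tuple[float, int]:
        sample, label = self.base[index]
        return self.transform(sample), label


def normalize(x: float) -> float:
    # Stand-in for a Normalize step with mean 0.5 and std 0.25.
    return (x - 0.5) / 0.25


def augment_then_normalize(x: float) -> float:
    # Stand-in for an augmentation (here a "flip" of the value) before normalizing.
    return normalize(1.0 - x)


# Toy data: single floats stand in for images.
raw_split = [(0.2, 0), (0.8, 1)]
train_view = TransformedView(raw_split, augment_then_normalize)
eval_view = TransformedView(raw_split, normalize)
```

The same raw data then yields augmented samples through `train_view` and plain normalized samples through `eval_view`.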

brando90 commented 2 years ago

after further inspection mini-imagenet seems to have the literature standard transforms: https://github.com/learnables/learn2learn/blob/6d490783d15f37642a636b13d0fd447294a2089d/learn2learn/vision/benchmarks/mini_imagenet_benchmark.py#L29

# MI
        normalize = Normalize(
            mean=[120.39586422/255.0, 115.59361427/255.0, 104.54012653/255.0],
            std=[70.68188272/255.0, 68.27635443/255.0, 72.54505529/255.0],
        )
        train_data_transforms = Compose([
            ToPILImage(),
            RandomCrop(84, padding=8),
            ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ])
        test_data_transforms = Compose([
            normalize,
        ])

but perhaps cifarfs is the only one that does not?
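
For comparison, the two sets of statistics quoted in this thread can be put on the same scale. The sketch below assumes only the standard per-channel formula `(value - mean) / std`; the helper name `normalize_pixel` is made up for illustration.

```python
from typing import List, Sequence

# CIFAR statistics from the rfs snippet, already on the [0, 1] scale.
CIFAR_MEAN = [0.5071, 0.4867, 0.4408]
CIFAR_STD = [0.2675, 0.2565, 0.2761]

# mini-ImageNet statistics from the l2l snippet, given on the [0, 255] scale
# and divided by 255 to match inputs in [0, 1].
MI_MEAN = [m / 255.0 for m in (120.39586422, 115.59361427, 104.54012653)]
MI_STD = [s / 255.0 for s in (70.68188272, 68.27635443, 72.54505529)]


def normalize_pixel(pixel: Sequence[float],
                    mean: Sequence[float],
                    std: Sequence[float]) -> List[float]:
    """Apply (value - mean) / std independently to each channel."""
    return [(p - m) / s for p, m, s in zip(pixel, mean, std)]


# A pixel equal to the channel means maps to zero in every channel.
centered = normalize_pixel(CIFAR_MEAN, CIFAR_MEAN, CIFAR_STD)
```

Both benchmarks therefore use the same kind of normalization; only the way the constants are written differs.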

brando90 commented 2 years ago

I guess what I am looking for is the official way to pass my own transforms to the data sets from l2l

seba-1511 commented 2 years ago

Hi @brando90,

Indeed, we've implemented some of the most common transforms used in the literature.

Regarding your last question: the point is to exactly replicate common benchmarks from the literature, so we're explicitly not allowing people to pass their own transforms. I recommend implementing your own benchmarks (like you already did in your other issue).
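
For readers writing their own benchmark, it may help to see what an N-way K-shot task amounts to. This is a plain-Python sketch of the idea (pick N classes, pick K samples from each, relabel the classes 0..N-1). It is not learn2learn's implementation, and `sample_task` is a hypothetical name.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def sample_task(labels: Sequence[int], ways: int, shots: int,
                rng: random.Random) -> Tuple[List[int], List[int]]:
    """Return (dataset indices, task labels) for one N-way K-shot task."""
    # Group dataset indices by their original class label.
    by_class: Dict[int, List[int]] = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    # Choose `ways` distinct classes, then `shots` distinct samples from each.
    chosen_classes = rng.sample(sorted(by_class), ways)
    indices: List[int] = []
    task_labels: List[int] = []
    for new_label, cls in enumerate(chosen_classes):
        for index in rng.sample(by_class[cls], shots):
            indices.append(index)
            task_labels.append(new_label)  # relabel classes as 0..ways-1
    return indices, task_labels


# Toy dataset: 10 classes with 20 samples each.
toy_labels = [c for c in range(10) for _ in range(20)]
task_indices, task_labels = sample_task(toy_labels, ways=5, shots=3,
                                        rng=random.Random(0))
```

Any image transform can then be applied to the samples at `task_indices` before they are passed to the model.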

brando90 commented 2 years ago

To avoid cross-posting too much: it seems to me the CIFAR-FS transforms are missing something, but your mini-ImageNet ones are not.



brando90 commented 2 years ago

@seba-1511 I still think this is wrong.

CIFAR-FS needs to have data transforms. The standard ones (from rfs) are the following:

# - cifar100 normalization according to rfs
mean = [0.5071, 0.4867, 0.4408]
std = [0.2675, 0.2565, 0.2761]
normalize_cifar100 = transforms.Normalize(mean=mean, std=std)

def get_transform(augment: bool):
    if augment:
        transform = transforms.Compose([
            # lambda x: Image.fromarray(x),
            transforms.RandomCrop(32, padding=4),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip(),
            # lambda x: np.asarray(x),
            transforms.ToTensor(),
            normalize_cifar100
        ])
    else:
        transform = transforms.Compose([
            # lambda x: Image.fromarray(x),
            transforms.ToTensor(),
            normalize_cifar100
        ])
    return transform

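
To make the two geometric steps above concrete, here is a rough pure-Python sketch of what zero-padding followed by a random crop back to the original size, and then a random horizontal flip, do to a small 2D grid. It only illustrates the idea on nested lists; it is not torchvision's code, and `pad_crop_flip` is a hypothetical name.

```python
import random
from typing import List

Image2D = List[List[int]]


def pad_crop_flip(image: Image2D, padding: int, rng: random.Random,
                  flip_prob: float = 0.5) -> Image2D:
    """Zero-pad, crop back to the original size at a random offset, maybe flip."""
    height, width = len(image), len(image[0])
    padded_width = width + 2 * padding

    # Surround the image with `padding` rows and columns of zeros.
    padded = [[0] * padded_width for _ in range(padding)]
    padded += [[0] * padding + list(row) + [0] * padding for row in image]
    padded += [[0] * padded_width for _ in range(padding)]

    # Pick the top-left corner of a crop with the original height and width.
    top = rng.randint(0, 2 * padding)
    left = rng.randint(0, 2 * padding)
    crop = [row[left:left + width] for row in padded[top:top + height]]

    # Mirror each row with probability `flip_prob`.
    if rng.random() < flip_prob:
        crop = [row[::-1] for row in crop]
    return crop


toy_image = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
augmented = pad_crop_flip(toy_image, padding=1, rng=random.Random(0))
```

The output always has the same shape as the input, which is why `RandomCrop(32, padding=4)` can be used on 32x32 CIFAR images without changing the model's input size.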
brando90 commented 2 years ago

This probably needs debugging, but here is an attempt at the whole code:

"""
refs:
    - inspired from: https://github.com/WangYueFt/rfs/blob/master/train_supervised.py
    - normalization: https://github.com/WangYueFt/rfs/blob/f8c837ba93c62dd0ac68a2f4019c619aa86b8421/dataset/transform_cfg.py#L104
"""
import os
import pickle
from argparse import Namespace
from pathlib import Path
from typing import Optional, Callable

import numpy as np
import torch
# Import the PIL.Image module (not the Image class) so Image.fromarray works below.
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import Compose

# - cifar100 normalization according to rfs
mean = [0.5071, 0.4867, 0.4408]
std = [0.2675, 0.2565, 0.2761]
normalize_cifar100 = transforms.Normalize(mean=mean, std=std)

def get_train_valid_test_data_loader_helper_for_cifarfs(args: Namespace) -> dict:
    train_kwargs = {'args': args,
                    'data_path': args.data_path,
                    'batch_size': args.batch_size,
                    'batch_size_eval': args.batch_size_eval,
                    'augment_train': args.augment_train,
                    'augment_val': args.augment_val,
                    'num_workers': args.num_workers,
                    'pin_memory': args.pin_memory,
                    'rank': args.rank,
                    'world_size': args.world_size,
                    'merge': None
                    }
    dataloaders: dict = get_rfs_union_sl_dataloader_cifarfs(**train_kwargs)
    return dataloaders

def get_transform(augment: bool):
    if augment:
        transform = transforms.Compose([
            # lambda x: Image.fromarray(x),
            transforms.RandomCrop(32, padding=4),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip(),
            # lambda x: np.asarray(x),
            transforms.ToTensor(),
            normalize_cifar100
        ])
    else:
        transform = transforms.Compose([
            # lambda x: Image.fromarray(x),
            transforms.ToTensor(),
            normalize_cifar100
        ])
    return transform

def get_transform_rfs(augment: bool):
    """
    this won't work for l2l data sets.
    """
    if augment:
        transform = transforms.Compose([
            lambda x: Image.fromarray(x),
            transforms.RandomCrop(32, padding=4),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip(),
            lambda x: np.asarray(x),
            transforms.ToTensor(),
            normalize_cifar100
        ])
    else:
        transform = transforms.Compose([
            lambda x: Image.fromarray(x),
            transforms.ToTensor(),
            normalize_cifar100
        ])
    return transform

class CIFAR100(Dataset):
    """support FC100 and CIFAR-FS"""

    def __init__(self, data_root, data_aug,
                 partition='train', pretrain=True,
                 is_sample=False, k=4096,
                 transform=None):
        super().__init__()
        # self.data_root = args.data_root
        self.data_root = data_root
        self.partition = partition
        # self.data_aug = args.data_aug
        self.data_aug = data_aug
        self.mean = [0.5071, 0.4867, 0.4408]
        self.std = [0.2675, 0.2565, 0.2761]
        self.normalize = transforms.Normalize(mean=self.mean, std=self.std)
        self.pretrain = pretrain

        if transform is None:
            if self.partition == 'train' and self.data_aug:
                self.transform = transforms.Compose([
                    lambda x: Image.fromarray(x),
                    transforms.RandomCrop(32, padding=4),
                    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                    transforms.RandomHorizontalFlip(),
                    lambda x: np.asarray(x),
                    transforms.ToTensor(),
                    self.normalize
                ])
            else:
                self.transform = transforms.Compose([
                    lambda x: Image.fromarray(x),
                    transforms.ToTensor(),
                    self.normalize
                ])
        else:
            self.transform = transform

        if self.pretrain:
            self.file_pattern = '%s.pickle'
        else:
            self.file_pattern = '%s.pickle'
        self.data = {}

        with open(os.path.join(self.data_root, self.file_pattern % partition), 'rb') as f:
            data = pickle.load(f, encoding='latin1')
            self.imgs = data['data']
            labels = data['labels']
            # adjust sparse labels to labels from 0 to n.
            cur_class = 0
            label2label = {}
            for idx, label in enumerate(labels):
                if label not in label2label:
                    label2label[label] = cur_class
                    cur_class += 1
            new_labels = []
            for idx, label in enumerate(labels):
                new_labels.append(label2label[label])
            self.labels = new_labels

        # pre-process for contrastive sampling
        self.k = k
        self.is_sample = is_sample
        if self.is_sample:
            self.labels = np.asarray(self.labels)
            self.labels = self.labels - np.min(self.labels)
            num_classes = np.max(self.labels) + 1

            self.cls_positive = [[] for _ in range(num_classes)]
            for i in range(len(self.imgs)):
                self.cls_positive[self.labels[i]].append(i)

            self.cls_negative = [[] for _ in range(num_classes)]
            for i in range(num_classes):
                for j in range(num_classes):
                    if j == i:
                        continue
                    self.cls_negative[i].extend(self.cls_positive[j])

            self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)]
            self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)]
            self.cls_positive = np.asarray(self.cls_positive)
            self.cls_negative = np.asarray(self.cls_negative)

    def __getitem__(self, item):
        img = np.asarray(self.imgs[item]).astype('uint8')
        img = self.transform(img)
        target = self.labels[item] - min(self.labels)

        if not self.is_sample:
            return img, target, item
        else:
            pos_idx = item
            replace = True if self.k > len(self.cls_negative[target]) else False
            neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace)
            sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
            return img, target, item, sample_idx

    def __len__(self):
        return len(self.labels)

def get_rfs_union_sl_dataloader_cifarfs(args: Namespace,
                                        data_path: Path,
                                        batch_size: int = 128,
                                        batch_size_eval: int = 64,
                                        augment_train: bool = True,
                                        augment_val: bool = False,
                                        num_workers: int = 0,  # DataLoader rejects negative values
                                        pin_memory: bool = False,
                                        rank: int = -1,
                                        world_size: int = 1,
                                        merge: Optional[Callable] = None,
                                        ) -> dict:
    """
    ref:
        - https://github.com/WangYueFt/rfs/blob/master/train_supervised.py
    """
    assert rank == -1 and world_size == 1, 'no DDP yet. Need to change dl but its not needed in (small) sl.'

    # args.num_workers = 2 if args.num_workers is None else args.num_workers
    # args.target_type = 'classification'
    # args.data_aug = True
    data_root: str = str(data_path)

    # -- get SL dataloaders
    train_trans, val_trans = get_transform_rfs(augment_train), get_transform_rfs(augment_val)
    train_loader = DataLoader(CIFAR100(data_root=data_root, data_aug=augment_train, partition='train',
                                       transform=train_trans),
                              batch_size=batch_size, shuffle=True, drop_last=True,
                              num_workers=num_workers)
    val_loader = DataLoader(CIFAR100(data_root=data_root, data_aug=augment_val, partition='val',
                                     transform=val_trans),
                            batch_size=batch_size_eval, shuffle=True, drop_last=False,
                            num_workers=num_workers)
    # test_loader = None  # note: since we are evaluating with meta-learning not SL it doesn't need to have this
    test_trans = get_transform_rfs(False)
    # data_aug is a bool flag; the test split gets its own non-augmenting transform
    test_loader = DataLoader(CIFAR100(data_root=data_root, data_aug=False, partition='test',
                                      transform=test_trans),
                             batch_size=batch_size_eval, shuffle=True, drop_last=False,
                             num_workers=num_workers)

    # -- get meta-dataloaders
    # not needed: we do not evaluate the meta-test error while training the model; that is done separately.

    # - for now since torchmeta always uses the meta-train or meta-val (but not both together) we won't allow the merge
    args.n_cls = 64

    # - return data loaders
    dataloaders: dict = {'train': train_loader, 'val': val_loader, 'test': test_loader}
    return dataloaders

# def get_rfs_union_sl_dataloader_fc100(args: Namespace) -> dict:
#     n_cls = 60

# -

def cifarfs_tasksets(
        train_ways=5,
        train_samples=10,
        test_ways=5,
        test_samples=10,
        root='~/data',
        data_augmentation=None,
        device=None,
        **kwargs,
):
    import torchvision as tv
    import learn2learn as l2l

    from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels

    from torchvision.transforms import (Compose, ToPILImage, ToTensor, RandomCrop, RandomHorizontalFlip,
                                        ColorJitter, Normalize)
    """Tasksets for CIFAR-FS benchmarks."""
    if data_augmentation is None:
        train_data_transforms = tv.transforms.ToTensor()
        test_data_transforms = tv.transforms.ToTensor()
    elif data_augmentation == 'normalize':
        train_data_transforms = Compose([
            lambda x: x / 255.0,
        ])
        test_data_transforms = train_data_transforms
    elif data_augmentation == 'rfs2020':
        # get_transform returns a single Compose object, not a list of callables
        train_data_transforms = get_transform(augment=True)
        test_data_transforms = get_transform(augment=False)
    else:
        raise ValueError('Invalid data_augmentation argument.')

    train_dataset = l2l.vision.datasets.CIFARFS(root=root,
                                                transform=train_data_transforms,
                                                mode='train',
                                                download=True)
    valid_dataset = l2l.vision.datasets.CIFARFS(root=root,
                                                transform=test_data_transforms,
                                                mode='validation',
                                                download=True)
    test_dataset = l2l.vision.datasets.CIFARFS(root=root,
                                               transform=test_data_transforms,
                                               mode='test',
                                               download=True)
    if device is not None:
        train_dataset = l2l.data.OnDeviceDataset(
            dataset=train_dataset,
            device=device,
        )
        valid_dataset = l2l.data.OnDeviceDataset(
            dataset=valid_dataset,
            device=device,
        )
        test_dataset = l2l.data.OnDeviceDataset(
            dataset=test_dataset,
            device=device,
        )
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)

    train_transforms = [
        NWays(train_dataset, train_ways),
        KShots(train_dataset, train_samples),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
    valid_transforms = [
        NWays(valid_dataset, test_ways),
        KShots(valid_dataset, test_samples),
        LoadData(valid_dataset),
        ConsecutiveLabels(valid_dataset),
        RemapLabels(valid_dataset),
    ]
    test_transforms = [
        NWays(test_dataset, test_ways),
        KShots(test_dataset, test_samples),
        LoadData(test_dataset),
        RemapLabels(test_dataset),
        ConsecutiveLabels(test_dataset),
    ]

    _datasets = (train_dataset, valid_dataset, test_dataset)
    _transforms = (train_transforms, valid_transforms, test_transforms)
    return _datasets, _transforms

def fc100_tasksets(
        train_ways=5,
        train_samples=10,
        test_ways=5,
        test_samples=10,
        root='~/data',
        data_augmentation=None,
        device=None,
        **kwargs,
):
    import torchvision as tv
    import learn2learn as l2l

    from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels

    from torchvision.transforms import (Compose, ToPILImage, ToTensor, RandomCrop, RandomHorizontalFlip,
                                        ColorJitter, Normalize)
    """Tasksets for FC100 benchmarks."""
    if data_augmentation is None:
        train_data_transforms = tv.transforms.ToTensor()
        test_data_transforms = tv.transforms.ToTensor()
    elif data_augmentation == 'normalize':
        train_data_transforms = Compose([
            lambda x: x / 255.0,
        ])
        test_data_transforms = train_data_transforms
    elif data_augmentation == 'rfs2020':
        train_data_transforms = get_transform(True)
        test_data_transforms = get_transform(False)
    else:
        raise ValueError('Invalid data_augmentation argument.')

    train_dataset = l2l.vision.datasets.FC100(root=root,
                                              transform=train_data_transforms,
                                              mode='train',
                                              download=True)
    valid_dataset = l2l.vision.datasets.FC100(root=root,
                                              transform=test_data_transforms,
                                              mode='validation',
                                              download=True)
    test_dataset = l2l.vision.datasets.FC100(root=root,
                                             transform=test_data_transforms,
                                             mode='test',
                                             download=True)
    if device is not None:
        train_dataset = l2l.data.OnDeviceDataset(
            dataset=train_dataset,
            device=device,
        )
        valid_dataset = l2l.data.OnDeviceDataset(
            dataset=valid_dataset,
            device=device,
        )
        test_dataset = l2l.data.OnDeviceDataset(
            dataset=test_dataset,
            device=device,
        )
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)

    train_transforms = [
        NWays(train_dataset, train_ways),
        KShots(train_dataset, train_samples),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
    valid_transforms = [
        NWays(valid_dataset, test_ways),
        KShots(valid_dataset, test_samples),
        LoadData(valid_dataset),
        ConsecutiveLabels(valid_dataset),
        RemapLabels(valid_dataset),
    ]
    test_transforms = [
        NWays(test_dataset, test_ways),
        KShots(test_dataset, test_samples),
        LoadData(test_dataset),
        RemapLabels(test_dataset),
        ConsecutiveLabels(test_dataset),
    ]

    _datasets = (train_dataset, valid_dataset, test_dataset)
    _transforms = (train_transforms, valid_transforms, test_transforms)
    return _datasets, _transforms

_TASKSETS = {
    # 'omniglot': omniglot_tasksets,
    # 'mini-imagenet': mini_imagenet_tasksets,
    # 'tiered-imagenet': tiered_imagenet_tasksets,
    'fc100': fc100_tasksets,
    'cifarfs': cifarfs_tasksets,
}

def get_tasksets(
        name,
        train_ways=5,
        train_samples=10,
        test_ways=5,
        test_samples=10,
        num_tasks=-1,
        root='~/data',
        data_augmentation=None,
        device=None,
        **kwargs,
):
    """
    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/benchmarks/)

    **Description**

    Returns the tasksets for a particular benchmark, using literature standard data and task transformations.

    The returned object is a namedtuple with attributes `train`, `validation`, `test` which
    correspond to their respective TaskDatasets.
    See `examples/vision/maml_miniimagenet.py` for an example.

    **Arguments**

    * **name** (str) - The name of the benchmark. Full list in `list_tasksets()`.
    * **train_ways** (int, *optional*, default=5) - The number of classes per train tasks.
    * **train_samples** (int, *optional*, default=10) - The number of samples per train tasks.
    * **test_ways** (int, *optional*, default=5) - The number of classes per test tasks. Also used for validation tasks.
    * **test_samples** (int, *optional*, default=10) - The number of samples per test tasks. Also used for validation tasks.
    * **num_tasks** (int, *optional*, default=-1) - The number of tasks in each TaskDataset.
    * **device** (torch.Device, *optional*, default=None) - If not None, tasksets are loaded as Tensors on `device`.
    * **root** (str, *optional*, default='~/data') - Where the data is stored.

    **Example**
    ~~~python
    train_tasks, validation_tasks, test_tasks = l2l.vision.benchmarks.get_tasksets('omniglot')
    batch = train_tasks.sample()

    or:

    tasksets = l2l.vision.benchmarks.get_tasksets('omniglot')
    batch = tasksets.train.sample()
    ~~~
    """
    import learn2learn as l2l

    from learn2learn.vision.benchmarks import BenchmarkTasksets
    # - unchanged l2l code, what I changed is what _TASKSETS has
    root = os.path.expanduser(root)

    # Load task-specific data and transforms
    datasets, transforms = _TASKSETS[name](train_ways=train_ways,
                                           train_samples=train_samples,
                                           test_ways=test_ways,
                                           test_samples=test_samples,
                                           root=root,
                                           data_augmentation=data_augmentation,
                                           device=device,
                                           **kwargs)
    train_dataset, validation_dataset, test_dataset = datasets
    train_transforms, validation_transforms, test_transforms = transforms

    # Instantiate the tasksets
    train_tasks = l2l.data.TaskDataset(
        dataset=train_dataset,
        task_transforms=train_transforms,
        num_tasks=num_tasks,
    )
    validation_tasks = l2l.data.TaskDataset(
        dataset=validation_dataset,
        task_transforms=validation_transforms,
        num_tasks=num_tasks,
    )
    test_tasks = l2l.data.TaskDataset(
        dataset=test_dataset,
        task_transforms=test_transforms,
        num_tasks=num_tasks,
    )
    return BenchmarkTasksets(train_tasks, validation_tasks, test_tasks)

# - get SL from l2l

def get_sl_l2l_datasets(root,
                        data_augmentation: str = 'rfs2020',
                        device=None
                        ) -> tuple:
    # Only the rfs2020 transforms are supported for supervised learning here.
    if data_augmentation == 'rfs2020':
        train_data_transforms = get_transform(True)
        test_data_transforms = get_transform(False)
    elif data_augmentation is None or data_augmentation == 'normalize':
        raise ValueError('only rfs2020 augmentation allowed')
    else:
        raise ValueError('Invalid data_augmentation argument.')

    import learn2learn
    train_dataset = learn2learn.vision.datasets.CIFARFS(root=root,
                                                        transform=train_data_transforms,
                                                        mode='train',
                                                        download=True)
    valid_dataset = learn2learn.vision.datasets.CIFARFS(root=root,
                                                        transform=train_data_transforms,
                                                        mode='validation',
                                                        download=True)
    test_dataset = learn2learn.vision.datasets.CIFARFS(root=root,
                                                       transform=test_data_transforms,
                                                       mode='test',
                                                       download=True)
    if device is not None:
        train_dataset = learn2learn.data.OnDeviceDataset(
            dataset=train_dataset,
            device=device,
        )
        valid_dataset = learn2learn.data.OnDeviceDataset(
            dataset=valid_dataset,
            device=device,
        )
        test_dataset = learn2learn.data.OnDeviceDataset(
            dataset=test_dataset,
            device=device,
        )
    return train_dataset, valid_dataset, test_dataset

def get_sl_l2l_cifarfs_dataloaders(args: Namespace) -> dict:
    train_dataset, valid_dataset, test_dataset = get_sl_l2l_datasets(root=args.data_path)

    from uutils.torch_uu.dataloaders.common import get_serial_or_distributed_dataloaders
    train_loader, val_loader = get_serial_or_distributed_dataloaders(
        train_dataset=train_dataset,
        val_dataset=valid_dataset,
        batch_size=args.batch_size,
        batch_size_eval=args.batch_size_eval,
        rank=args.rank,
        world_size=args.world_size
    )
    _, test_loader = get_serial_or_distributed_dataloaders(
        train_dataset=test_dataset,
        val_dataset=test_dataset,
        batch_size=args.batch_size,
        batch_size_eval=args.batch_size_eval,
        rank=args.rank,
        world_size=args.world_size
    )

    args.n_cls = 64
    dataloaders: dict = {'train': train_loader, 'val': val_loader, 'test': test_loader}
    return dataloaders

# - tests

def l2l_sl_dl():
    print('starting...')
    args = Namespace(data_path='~/data/l2l_data/', batch_size=8, batch_size_eval=2, rank=-1, world_size=1)
    args.data_path = Path('~/data/l2l_data/').expanduser()
    dataloaders = get_sl_l2l_cifarfs_dataloaders(args)
    max_val = torch.tensor(-1)
    for i, batch in enumerate(dataloaders['train']):
        # print(batch[1])
        # print(batch[0])
        max_val = max(list(batch[1]) + [max_val])
        # print(f'{max_val}')
        # if 63 in batch[1]:
        #     break
    print(f'--> TRAIN FINAL: {max_val=}')
    assert max_val == 63  # largest label, assuming 64 train classes labelled 0..63

    max_val = torch.tensor(-1)
    for i, batch in enumerate(dataloaders['val']):
        # print(batch[1])
        max_val = max(list(batch[1]) + [max_val])
        # print(f'{max_val}')
        # if 15 in batch[1]:
        #     break
    print(f'--> VAL FINAL: {max_val=}')
    assert max_val == 15  # largest label, assuming 16 validation classes labelled 0..15

    max_val = torch.tensor(-1)
    for i, batch in enumerate(dataloaders['test']):
        # print(batch[1])
        max_val = max(list(batch[1]) + [max_val])
        # print(f'{max_val}')
        # if 19 in batch[1]:
        #     break
    print(f'--> TEST FINAL: {max_val=}')

    assert max_val == 19  # largest label, assuming 20 test classes labelled 0..19

def rfs_test():
    args = Namespace()
    # args = lambda x: None
    # args.n_ways = 5
    # args.n_shots = 1
    # args.n_queries = 12
    # args.data_root = 'data'
    args.data_root = Path('~/data/CIFAR-FS/').expanduser()
    args.data_aug = True
    # args.n_test_runs = 5
    # args.n_aug_support_samples = 1
    imagenet = CIFAR100(args.data_root, args.data_aug, 'train')
    print(len(imagenet))
    print(imagenet.__getitem__(500)[0].shape)

    # metaimagenet = MetaCIFAR100(args, 'train')
    # print(len(metaimagenet))
    # print(metaimagenet.__getitem__(500)[0].size())
    # print(metaimagenet.__getitem__(500)[1].shape)
    # print(metaimagenet.__getitem__(500)[2].size())
    # print(metaimagenet.__getitem__(500)[3].shape)

if __name__ == '__main__':
    l2l_sl_dl()
    print('Done!\a\n')

I think you can find a working version in my ultimate-utils library.
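
The label handling in the `CIFAR100` class above can be checked in isolation. The sketch below restates the "sparse labels to 0..n-1" loop as a standalone function, under the assumption that first-seen order defines the new label; `remap_labels` is a hypothetical name. It also shows why the largest remapped label is the number of classes minus one, not the dataset length.

```python
from typing import Dict, Hashable, List, Sequence


def remap_labels(labels: Sequence[Hashable]) -> List[int]:
    """Map arbitrary labels to consecutive integers in first-seen order."""
    mapping: Dict[Hashable, int] = {}
    remapped: List[int] = []
    for label in labels:
        if label not in mapping:
            # The next unused integer becomes this label's new id.
            mapping[label] = len(mapping)
        remapped.append(mapping[label])
    return remapped


sparse = [10, 10, 42, 7, 42]
dense = remap_labels(sparse)
num_classes = len(set(sparse))
```

With 64 training classes, for example, the same reasoning gives a largest label of 63.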

brando90 commented 1 year ago

@seba-1511 why doesn't cifar have data augmentation implemented?