learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License

Extending get_tasksets such that we can provide our own transforms for the train, val, test sets #304

Closed brando90 closed 2 years ago

brando90 commented 2 years ago

Is it possible to expand

def get_tasksets(
    name,
    train_ways=5,
    train_samples=10,
    test_ways=5,
    test_samples=10,
    num_tasks=-1,
    root='~/data',
    device=None,
    **kwargs,
):

so that we can supply the train, val, and test transforms we want?

e.g., is the following

def get_tasksets2(
    name,
    transforms,
    train_ways=5,
    train_samples=10,
    test_ways=5,
    test_samples=10,
    num_tasks=-1,
    root='~/data',
    device=None,
    **kwargs,
):
    """
    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/benchmarks/)

    **Description**

    Returns the tasksets for a particular benchmark, using literature standard data and task transformations.

    The returned object is a namedtuple with attributes `train`, `validation`, `test` which
    correspond to their respective TaskDatasets.
    See `examples/vision/maml_miniimagenet.py` for an example.

    **Arguments**

    * **name** (str) - The name of the benchmark. Full list in `list_tasksets()`.
    * **transforms** (tuple) - A (train, validation, test) triple of task transforms, used in place of the benchmark defaults.
    * **train_ways** (int, *optional*, default=5) - The number of classes per train task.
    * **train_samples** (int, *optional*, default=10) - The number of samples per train task.
    * **test_ways** (int, *optional*, default=5) - The number of classes per test task. Also used for validation tasks.
    * **test_samples** (int, *optional*, default=10) - The number of samples per test task. Also used for validation tasks.
    * **num_tasks** (int, *optional*, default=-1) - The number of tasks in each TaskDataset.
    * **device** (torch.Device, *optional*, default=None) - If not None, tasksets are loaded as Tensors on `device`.
    * **root** (str, *optional*, default='~/data') - Where the data is stored.

    **Example**
    ~~~python
    train_tasks, validation_tasks, test_tasks = l2l.vision.benchmarks.get_tasksets('omniglot')
    batch = train_tasks.sample()

    # or:
    tasksets = l2l.vision.benchmarks.get_tasksets('omniglot')
    batch = tasksets.train.sample()
    ~~~
    """
    root = os.path.expanduser(root)

    # Load task-specific data; discard the benchmark's default transforms,
    # since the caller supplies their own via the `transforms` argument.
    datasets, _ = _TASKSETS[name](train_ways=train_ways,
                                  train_samples=train_samples,
                                  test_ways=test_ways,
                                  test_samples=test_samples,
                                  root=root,
                                  device=device,
                                  **kwargs)
    train_dataset, validation_dataset, test_dataset = datasets
    train_transforms, validation_transforms, test_transforms = transforms

    # Instantiate the tasksets
    train_tasks = l2l.data.TaskDataset(
        dataset=train_dataset,
        task_transforms=train_transforms,
        num_tasks=num_tasks,
    )
    validation_tasks = l2l.data.TaskDataset(
        dataset=validation_dataset,
        task_transforms=validation_transforms,
        num_tasks=num_tasks,
    )
    test_tasks = l2l.data.TaskDataset(
        dataset=test_dataset,
        task_transforms=test_transforms,
        num_tasks=num_tasks,
    )
    return BenchmarkTasksets(train_tasks, validation_tasks, test_tasks)

correct?
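The intended flow (unpacking the caller's transforms triple into per-split task transforms, rather than the benchmark defaults) can be sketched without learn2learn at all. Every name below is a hypothetical stand-in, not learn2learn API:

```python
from collections import namedtuple

# Hypothetical stand-in for learn2learn's BenchmarkTasksets namedtuple.
BenchmarkTasksets = namedtuple('BenchmarkTasksets', ('train', 'validation', 'test'))

def make_taskset(dataset, task_transforms):
    # Stand-in for l2l.data.TaskDataset: just record what it receives.
    return {'dataset': dataset, 'task_transforms': task_transforms}

def get_tasksets_sketch(datasets, transforms):
    """Both arguments are (train, validation, test) triples."""
    train_dataset, validation_dataset, test_dataset = datasets
    # Unpack the caller-provided transforms instead of benchmark defaults.
    train_tf, validation_tf, test_tf = transforms
    return BenchmarkTasksets(
        make_taskset(train_dataset, train_tf),
        make_taskset(validation_dataset, validation_tf),
        make_taskset(test_dataset, test_tf),
    )

tasksets = get_tasksets_sketch(
    datasets=('train-ds', 'valid-ds', 'test-ds'),
    transforms=('train-tf', 'valid-tf', 'test-tf'),
)
print(tasksets.validation['task_transforms'])  # valid-tf
```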

seba-1511 commented 2 years ago

This is tricky as it conflicts with the data_augmentation argument of mini-ImageNet and tiered-ImageNet. In fact, the point of get_tasksets is to replicate common benchmarks in the literature so that we can make apples-to-apples comparisons.

In your case, you might want to define your own benchmark setups with your own transforms like you did (in which case you probably want to use transforms instead of train_transforms, valid_transforms, and test_transforms at the bottom of your implementation).

brando90 commented 2 years ago

hmmm... but most benchmarks use some type of data augmentation... you already have the "right" transforms for mini-ImageNet. The CIFAR-FS ones seem to give only tensors. Isn't that non-standard for CIFAR-FS? (I guess I am arguing that the ones I am using for CIFAR-FS are the apples-to-apples ones. The mini-ImageNet ones seem correct with the flag you gave.) Correct me if I am wrong.


seba-1511 commented 2 years ago

Indeed, I think we can add the rfs augmentation you proposed in your PR -- thanks for that! I'm having a look at it now, will comment soon.

brando90 commented 2 years ago

Sounds good. Let me know if you plan to reject it for not being standard or not being apples-to-apples. I proposed those augmentations because I indeed thought they were the standard CIFAR-FS transformations. I don't know anyone who passes the image tensors directly as the current code does -- hence my suggestion. :)

Thanks for l2l, I am really enjoying it! It's simple and the memory usage seems to be much lower.
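The disagreement boils down to which pipeline each split gets: augmentation transforms at train time versus a deterministic pass at evaluation. A torchvision-free sketch of that composition pattern (all names hypothetical):

```python
import random

def compose(*fns):
    # Minimal stand-in for torchvision.transforms.Compose.
    def pipeline(x):
        for fn in fns:
            x = fn(x)
        return x
    return pipeline

def normalize(pixels):
    # Shift and scale values, as a Normalize transform would.
    return [(v - 0.5) / 0.5 for v in pixels]

def random_flip(pixels):
    # Augmentation is random, so it belongs in the train pipeline only.
    return pixels[::-1] if random.random() < 0.5 else pixels

train_transform = compose(random_flip, normalize)  # augmented
eval_transform = compose(normalize)                # deterministic

print(eval_transform([0.0, 1.0]))  # [-1.0, 1.0]
```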
