learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License
2.61k stars 350 forks source link

Code example for UnionMetaDataset doesn't work #356

Closed: brando90 closed this issue 1 year ago

brando90 commented 1 year ago
    import learn2learn as l2l
    train = torchvision.datasets.CIFARFS(root="/tmp/mnist", mode="train")
    train = l2l.data.MetaDataset(train)
    valid = torchvision.datasets.CIFARFS(root="/tmp/mnist", mode="validation")
    valid = l2l.data.MetaDataset(valid)
    test = torchvision.datasets.CIFARFS(root="/tmp/mnist", mode="test")
    test = l2l.data.MetaDataset(test)
    from learn2learn.data import UnionMetaDataset
    union = UnionMetaDataset([train, valid, test])
    assert len(union.labels) == 100

Running it fails with:

AttributeError: module 'torchvision.datasets' has no attribute 'CIFARFS'
brando90 commented 1 year ago

This version runs, but the assertion fails:

    from pathlib import Path
    import learn2learn as l2l

    root = Path("~/data/").expanduser()
    # root = Path(".").expanduser()
    train = torchvision.datasets.CIFAR100(root=root, train=True, download=True)
    train = l2l.data.MetaDataset(train)
    print(f'{len(train.labels)=}')
    # valid = torchvision.datasets.CIFAR100(root="/tmp/mnist", mode="validation")
    # valid = l2l.data.MetaDataset(valid)
    test = torchvision.datasets.CIFAR100(root=root, train=False, download=True)
    test = l2l.data.MetaDataset(test)
    print(f'{len(test.labels)=}')

    from learn2learn.data import UnionMetaDataset
    # union = UnionMetaDataset([train, valid, test])
    union = UnionMetaDataset([train, test])
    assert len(union.labels) == 100, f'Error, got instead: {len(union.labels)=}.'
seba-1511 commented 1 year ago

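It fails because you also need the validation split for the labels to sum to 100. CIFAR-FS partitions the 100 CIFAR-100 classes into disjoint train, validation, and test splits (64/16/20 classes). Only the union of all three splits contains all 100 labels.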