Closed brando90 closed 2 years ago
this runs but assert fails:
from pathlib import Path
import learn2learn as l2l
root = Path("~/data/").expanduser()
# root = Path(".").expanduser()
train = torchvision.datasets.CIFAR100(root=root, train=True, download=True)
train = l2l.data.MetaDataset(train)
print(f'{len(train.labels)=}')
# valid = torchvision.datasets.CIFAR100(root="/tmp/mnist", mode="validation")
# valid = l2l.data.MetaDataset(valid)
test = torchvision.datasets.CIFAR100(root=root, train=False, download=True)
test = l2l.data.MetaDataset(test)
print(f'{len(test.labels)=}')
from learn2learn.data import UnionMetaDataset
# union = UnionMetaDataset([train, valid, test])
union = UnionMetaDataset([train, test])
assert len(union.labels) == 100, f'Error, got instead: {len(union.labels)=}.'
Fails because you need the validation set too to sum to 100.
error: