Hi! `len(dataset)` returns the number of tasks in the dataset. For a dataset like Omniglot, each task is created as a combination of multiple classes (this is a `CombinationMetaDataset` in Torchmeta), so the number of tasks you can create is combinatorially large.
In the case of Omniglot, you have 1028 classes in the meta-training split. In fact, since there are class augmentations (rotations of 90/180/270 degrees), this increases the pool of possible classes to 1028 * 4 = 4112. Since you want to create 6-way tasks, the total number of tasks is C(4112, 6) = 4112! / (6! * 4106!) ≈ 6689615974037915648. For a larger number of ways, this number is so large that it no longer fits in an index-sized integer (the maximum value `len()` is allowed to return), hence the `OverflowError`.
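For reference, here is a minimal sketch (plain Python with `math.comb`, not Torchmeta internals) of where that count comes from; the 4112 figure is the 1028 meta-training classes times the 4 rotated variants of each class:

```python
import math

num_base_classes = 1028      # Omniglot meta-training classes
num_variants = 4             # original + rotations of 90/180/270 degrees
ways = 6                     # classes combined into one task

pool = num_base_classes * num_variants   # 4112 candidate classes
num_tasks = math.comb(pool, ways)        # C(4112, 6), on the order of 6.7e18

print(pool, num_tasks)
```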
Thank you! That makes perfect sense. And if I set `shuffle=False` in the data loader, I can make sure that even if I test several methods on a partial dataset, they are using the same chunk of data.
`shuffle=False` in the data loader means that the tasks will be returned in a sequential order. For example in Omniglot, if the classes are numbered between 0 and 4111, then the tasks will be returned in this order (for 6-way classification, where each tuple lists the classes selected for the task):
(0, 1, 2, 3, 4, 5) -> (0, 1, 2, 3, 4, 6) -> (0, 1, 2, 3, 4, 7) -> ... -> (4105, 4107, 4108, 4109, 4110, 4111) -> (4106, 4107, 4108, 4109, 4110, 4111)
So they are not using the same chunk of data.
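To make that ordering concrete, here is a small illustration (plain Python with `itertools.combinations`, which enumerates tuples in the same lexicographic order described above; this is not Torchmeta's actual sampler code):

```python
from itertools import combinations, islice

num_classes = 4112   # augmented Omniglot class pool
ways = 6

# Sequential (shuffle=False-style) enumeration of 6-way class combinations
ordered_tasks = combinations(range(num_classes), ways)

for task in islice(ordered_tasks, 3):
    print(task)
# (0, 1, 2, 3, 4, 5)
# (0, 1, 2, 3, 4, 6)
# (0, 1, 2, 3, 4, 7)
```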
Well, then I'd better set `shuffle=True`. But how do I know my algorithm is comparable with others if we are using different tasks? I saw that in your MAML demo, and in the implementation claimed to be robust at https://github.com/tristandeleu/pytorch-maml/, you randomly picked 100 tasks as training episodes. I was wondering whether, in practice, it makes much of a difference given enough epochs. Again, thank you for being patient and kind.
In practice, it won't make much of a difference assuming you do N trials with each of M seeds and consider the range.
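As a rough sketch of that protocol (assuming the `omniglot` helper and `BatchMetaDataLoader` shown in the Torchmeta README; the exact seeding needed for fully reproducible task streams may depend on the sampler internals):

```python
import random
import numpy as np
import torch
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader

def run_trial(seed, num_batches=100):
    # Seed the global RNGs so the randomly sampled tasks of each trial are reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    dataset = omniglot("data", ways=6, shots=10, test_shots=15,
                       meta_train=True, download=True)
    dataloader = BatchMetaDataLoader(dataset, batch_size=16, shuffle=True)

    for batch_idx, batch in enumerate(dataloader):
        if batch_idx >= num_batches:
            break
        train_inputs, train_targets = batch["train"]
        test_inputs, test_targets = batch["test"]
        # ... meta-train / evaluate the algorithm on this batch of tasks ...

# Run several trials with different seeds and report the range of results.
for seed in [0, 1, 2, 3, 4]:
    run_trial(seed)
```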
Hi, I noticed the description of the dataset: "The Omniglot dataset [1]. A dataset of 1623 handwritten characters from 50 different alphabets. The meta train/validation/test splits used in [3] are taken from [this repository]. These splits are over 1028/172/423 classes." However, when I tested

```python
dataset = omniglot(data_path, ways=6, shots=10, test_shots=15, meta_train=True, download=True)
print(len(dataset))
```

the result is this huge number: 6689615974037915648. Also, when I use the dataloader, the total number of batches does not seem to equal 1028. If I change the argument "ways" to something bigger than 10, I get this error: "OverflowError: cannot fit 'int' into an index-sized integer". Is there any misunderstanding of this dataset here? Thanks for your explanation.
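For reference, the `OverflowError` itself comes from CPython rather than Torchmeta: `len()` must return a value that fits in an index-sized integer (at most `sys.maxsize`), and the task count exceeds that once the number of ways grows. A rough, hypothetical illustration (not Torchmeta code):

```python
import math
import sys

class FakeTaskDataset:
    """Stand-in for a combination-based meta-dataset; only __len__ matters here."""
    def __init__(self, num_classes, ways):
        self.num_tasks = math.comb(num_classes, ways)

    def __len__(self):
        return self.num_tasks

small = FakeTaskDataset(num_classes=4112, ways=6)
print(len(small) <= sys.maxsize)    # True: C(4112, 6) still fits in an index-sized integer

huge = FakeTaskDataset(num_classes=4112, ways=20)
# len(huge) raises:
# OverflowError: cannot fit 'int' into an index-sized integer
```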