mboudiaf / pytorch-meta-dataset

A non-official, 100% PyTorch implementation of the META-DATASET benchmark for few-shot classification

squeeze() needs to be added to support, support_labels, query, query_labels #4

jfb54 closed this issue 3 years ago

jfb54 commented 3 years ago

When support, query, support_labels and query_labels are returned from the DataLoader, their first dimension has size 1. This dimension is redundant and will usually cause problems when the tensors are fed into a network. Applying torch.squeeze(x, dim=0) to each tensor fixes this.
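A minimal sketch of the suggested fix, assuming each tensor arrives with a leading dimension of size 1 (one task per batch). The shapes used here (5 support and 10 query images of size 3×84×84, 5 classes) are made up for illustration.

```python
import torch

# Hypothetical task as returned by a DataLoader with batch_size=1:
# every tensor carries a redundant leading dimension of size 1.
support = torch.randn(1, 5, 3, 84, 84)
support_labels = torch.randint(0, 5, (1, 5))
query = torch.randn(1, 10, 3, 84, 84)
query_labels = torch.randint(0, 5, (1, 10))

# Remove only the task dimension; dim=0 leaves any other size-1 axis untouched.
support, support_labels, query, query_labels = (
    torch.squeeze(x, dim=0) for x in (support, support_labels, query, query_labels)
)

print(support.shape)  # torch.Size([5, 3, 84, 84])
```

Passing dim=0 explicitly matters: a bare squeeze() would also drop any other size-1 axis, for example the support dimension of a 1-way, 1-shot task.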

mboudiaf commented 3 years ago

So the idea here is that the first dimension is dedicated to batching tasks. In some problems you may want to create batches of tasks instead of processing one task at a time; this is only possible if the number of support and query samples is fixed across tasks. To do so, you can set batch_size at https://github.com/mboudiaf/pytorch-meta-dataset/blob/5c4e85b149cf7079789190a6326c73bcc7efd1f6/example.py#L131 to a value greater than 1.
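A minimal sketch of where the leading dimension comes from, assuming a hypothetical FixedShapeTaskDataset in which every item is one task with the same number of support and query samples. This class is made up for illustration and is not the repository's data pipeline. PyTorch's default collate function stacks the per-task tensors along a new first dimension, so batch_size=1 yields a leading axis of size 1 and batch_size=4 yields a leading axis of size 4. If tasks had different numbers of samples, this stacking would fail, which is why batching tasks requires fixed support/query sizes.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class FixedShapeTaskDataset(Dataset):
    """Hypothetical dataset: each item is one task with fixed support/query sizes."""

    def __init__(self, num_tasks=8, n_way=5, n_support=5, n_query=10, image_shape=(3, 84, 84)):
        self.num_tasks = num_tasks
        self.n_way = n_way
        self.n_support = n_support
        self.n_query = n_query
        self.image_shape = image_shape

    def __len__(self):
        return self.num_tasks

    def __getitem__(self, idx):
        # Random tensors stand in for real images and labels.
        support = torch.randn(self.n_support, *self.image_shape)
        support_labels = torch.randint(0, self.n_way, (self.n_support,))
        query = torch.randn(self.n_query, *self.image_shape)
        query_labels = torch.randint(0, self.n_way, (self.n_query,))
        return support, support_labels, query, query_labels


# The default collate function stacks 4 tasks along a new first dimension.
loader = DataLoader(FixedShapeTaskDataset(), batch_size=4)
support, support_labels, query, query_labels = next(iter(loader))
print(support.shape)  # torch.Size([4, 5, 3, 84, 84])
```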

For batch_size > 1, you will have to do some reshaping for the tensors to work properly when fed into the network. For instance, if support has shape [batch_size, num_support, c, h, w], you can reshape it to [batch_size * num_support, c, h, w], feed it to your network, and reshape the output to [batch_size, num_support, *], where * stands for the network's output dimensions.
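A minimal sketch of this reshape, feed, reshape pattern, assuming a small hypothetical convolutional backbone that maps each image to a 16-dimensional feature vector. The backbone and the concrete sizes are illustrative only.

```python
import torch
import torch.nn as nn

batch_size, num_support, c, h, w = 4, 5, 3, 84, 84
support = torch.randn(batch_size, num_support, c, h, w)

# Hypothetical feature extractor standing in for a real backbone.
backbone = nn.Sequential(
    nn.Conv2d(c, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)

# Merge the task and sample dimensions so the network sees an ordinary image batch.
flat = support.reshape(batch_size * num_support, c, h, w)
features = backbone(flat)  # shape: [batch_size * num_support, 16]

# Split the output back into per-task groups, keeping the network's trailing dimensions.
features = features.reshape(batch_size, num_support, *features.shape[1:])
print(features.shape)  # torch.Size([4, 5, 16])
```

The network itself never needs to know about tasks; only the two reshapes around it do.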

jfb54 commented 3 years ago

Got it, thanks for clarifying. For some reason, meta-dataset (even the learner code) has no concept of a batch; it only knows about tasks, and that is what threw me off.