renesax14 closed this issue 4 years ago.
The datasets in Torchmeta are responsible for creating the episodes, so there is currently no way to create a non-episodic version of the dataset (i.e., iterating over the images without creating episodes). That's why it doesn't behave well with the standard PyTorch `DataLoader` class out of the box. However, the data is there, so it should be possible to add a wrapper around the dataset that converts it to a non-episodic dataset. I will give that a try, thanks for the suggestion!
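The wrapper idea can be sketched in plain Python. This is a minimal illustration that assumes the data is stored as a list of samples per class; `FlatClassDataset` is a hypothetical name, and this is not how Torchmeta implements its wrapper. It flattens the per-class storage into one integer-indexed sequence, which is the map-style protocol (`__len__` / `__getitem__`) that a standard `DataLoader` consumes.

```python
# Minimal sketch, NOT Torchmeta's implementation.
# Assumption (hypothetical): data is stored as a list of per-class sample lists.

class FlatClassDataset:
    """Expose per-class samples as one flat, integer-indexed dataset.

    Implements the map-style protocol (__len__ / __getitem__) that a
    standard DataLoader expects, so no episodes are created.
    """

    def __init__(self, samples_per_class, class_names):
        self.samples_per_class = samples_per_class
        self.class_names = class_names
        # Precompute (class_index, sample_index) for every flat index.
        self._index = [
            (c, i)
            for c, samples in enumerate(samples_per_class)
            for i in range(len(samples))
        ]

    def __len__(self):
        return len(self._index)

    def __getitem__(self, index):
        # Raises IndexError past the end, like any sequence.
        class_idx, sample_idx = self._index[index]
        sample = self.samples_per_class[class_idx][sample_idx]
        return sample, self.class_names[class_idx]


# Usage with toy data: strings stand in for images.
data = [["a0", "a1"], ["b0", "b1", "b2"]]
flat = FlatClassDataset(data, ["class_a", "class_b"])
print(len(flat))  # 5
print(flat[3])    # ('b1', 'class_b')
```

Because each flat index maps to exactly one (image, class) pair, shuffling and batching can be left entirely to the `DataLoader`'s defaults.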
If I wrap the meta-set with a normal PyTorch DataLoader, would that do what I want?
Not exactly, because the standard PyTorch `DataLoader` uses certain defaults that are not quite compatible with Torchmeta datasets. See https://github.com/tristandeleu/pytorch-meta/issues/76#issuecomment-656899638.
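To make the kind of incompatibility concrete, here is a simplified toy model (hypothetical classes, not Torchmeta's code): an episodic dataset whose items are indexed by a tuple of class indices, as opposed to the plain integers that a `DataLoader`'s default sampler produces. The linked comment remains the authoritative explanation of the actual defaults involved.

```python
import itertools
import math

# Toy model (hypothetical, simplified): each item is a task built from a
# tuple of `ways` distinct class indices, so the natural index is a tuple
# and the length is the number of class combinations.
class ToyEpisodicDataset:
    def __init__(self, num_classes, ways):
        self.num_classes = num_classes
        self.ways = ways

    def __len__(self):
        return math.comb(self.num_classes, self.ways)

    def __getitem__(self, index):
        if not isinstance(index, tuple):
            raise TypeError("expected a tuple of class indices, got %r" % (index,))
        return {"classes": index}  # stands in for a task with support/query sets


dataset = ToyEpisodicDataset(num_classes=5, ways=2)

# A default sampler yields plain integers 0..len(dataset)-1 ...
default_indices = list(range(len(dataset)))
try:
    dataset[default_indices[0]]
except TypeError as err:
    print("default index fails:", err)

# ... whereas an episodic sampler would yield tuples of class indices.
episodic_index = next(itertools.combinations(range(dataset.num_classes), dataset.ways))
print(dataset[episodic_index])  # {'classes': (0, 1)}
```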
On master, you can now use the `NonEpisodicWrapper` to wrap a Torchmeta dataset into something which will be compatible with the defaults in `DataLoader`. For example:
```python
from torchmeta.datasets.helpers import miniimagenet
from torchmeta.utils.data import NonEpisodicWrapper
from torch.utils.data import DataLoader

dataset = miniimagenet('data', ways=5, shots=5, meta_train=True, download=True)
dataset = NonEpisodicWrapper(dataset)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

for inputs, targets in dataloader:
    print(f'inputs.shape = {inputs.shape}')  # inputs.shape = torch.Size([16, 3, 84, 84])
    targets, class_augmentations = targets
    print(f'targets = {targets}')  # targets = ('n03400231', 'n04258138', 'n03888605', 'n04389033', 'n03400231', 'n04243546', 'n02823428', 'n02105505', 'n03908618', 'n02747177', 'n02101006', 'n01770081', 'n03476684', 'n02687172', 'n02966193', 'n04435653')
    print(f'class_augmentations = {class_augmentations}')  # class_augmentations = tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    break
```
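Note that the targets come back as class names (WordNet IDs such as `'n03400231'`) rather than integers. For standard classification pre-training, something has to map them to integer labels before a loss such as cross-entropy can be applied. A minimal sketch, with hypothetical helper names (`build_label_map`, `encode_targets`) that are not part of Torchmeta:

```python
# Hypothetical helpers, not part of Torchmeta: convert class-name targets
# into integer labels suitable for a standard classification loss.

def build_label_map(class_names):
    """Assign a stable integer to each distinct class name (sorted order)."""
    return {name: idx for idx, name in enumerate(sorted(set(class_names)))}


def encode_targets(targets, label_map):
    """Convert a batch of class-name targets into a list of integer labels."""
    return [label_map[name] for name in targets]


# Usage with a few of the class names from the batch printed above.
batch_targets = ('n03400231', 'n04258138', 'n03888605', 'n03400231')
label_map = build_label_map(batch_targets)
print(label_map)  # {'n03400231': 0, 'n03888605': 1, 'n04258138': 2}
print(encode_targets(batch_targets, label_map))  # [0, 2, 1, 0]
```

In an actual training run, the label map would need to be built once from the full set of training classes rather than from a single batch; otherwise the same class could receive different indices in different batches. The resulting list can then be turned into a tensor with `torch.tensor(...)` before computing the loss.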
I wanted to pre-train a model using normal training (no meta-learning, no inner-loop adaptation, nothing like that): just an old-fashioned PyTorch DataLoader, but with a few-shot learning dataset used for pre-training with standard epoch-based training.
For that, do I just do: