tristandeleu / pytorch-meta

A collection of extensions and data-loaders for few-shot learning & meta-learning in PyTorch
https://tristandeleu.github.io/pytorch-meta/
MIT License

What's the meaning of the argument 'dataset_transform' in api "BatchMetaDataLoader"? #120

Closed MrDavidG closed 2 years ago

MrDavidG commented 3 years ago

Thanks for the library, it makes meta training much easier than before.

I wonder what the meaning of 'dataset_transform' in BatchMetaDataLoader is. According to the following toy example, dataset_transform seems to specify the number of samples per task during training/testing:

from torchmeta.transforms import ClassSplitter
from torchmeta.datasets import Omniglot

transform = ClassSplitter(num_samples_per_class={'train': 5, 'test': 15})
dataset = Omniglot('data', num_classes_per_task=5, dataset_transform=transform, meta_train=True)
task = dataset.sample_task()
print(task.keys())
print(len(task['train']), len(task['test']))

However, here is my example. I have already set test_shots and shots in the miniimagenet helper, so do I also need to set dataset_transform for BatchMetaDataLoader?

import torchmeta.datasets.helpers
from torchmeta.utils.data import BatchMetaDataLoader

dataset = torchmeta.datasets.helpers.miniimagenet(
        folder=data_path,
        shots=5,
        ways=5,
        test_shots=15,
        shuffle=True,
        meta_split='train',
        transform=transform,
        download=True
)
dataloader = BatchMetaDataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=8,
        pin_memory=True
)

Also, does the shuffle argument mean the same thing in both calls?

tristandeleu commented 3 years ago

There is only a dataset_transform argument for a MetaDataset (the object containing the dataset). This is a transformation which is applied to a dataset; the major use-case in Torchmeta is to implement the split of data into train/test (support/query) sets. There is however no dataset_transform argument in BatchMetaDataLoader (the data-loader, which just iterates over the dataset).
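To make the distinction concrete, here is a minimal, torchmeta-free sketch of what a dataset_transform like ClassSplitter does conceptually: it is applied to each sampled task (splitting it into train/test, i.e. support/query, sets), not to the data-loader. The function names `class_splitter`/`split_task` and the toy task dictionary are hypothetical, for illustration only.

```python
import random

def class_splitter(num_train_per_class, num_test_per_class, seed=0):
    """Toy analogue of a dataset_transform: returns a function that
    splits each sampled task into 'train' (support) and 'test' (query)."""
    rng = random.Random(seed)

    def split_task(task):
        # `task` maps each class label to its list of samples.
        train, test = [], []
        for label, samples in task.items():
            samples = list(samples)
            rng.shuffle(samples)
            train += [(s, label) for s in samples[:num_train_per_class]]
            test += [(s, label) for s in
                     samples[num_train_per_class:
                             num_train_per_class + num_test_per_class]]
        return {'train': train, 'test': test}

    return split_task

# A sampled 2-way task with 8 samples per class.
task = {0: range(8), 1: range(8, 16)}
split = class_splitter(num_train_per_class=2, num_test_per_class=3)(task)
print(len(split['train']), len(split['test']))  # 4 6
```

Because the miniimagenet helper already configures this split from its shots/ways/test_shots arguments, no extra dataset_transform is expected when constructing the loader.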