tristandeleu / pytorch-meta

A collection of extensions and data-loaders for few-shot learning & meta-learning in PyTorch
https://tristandeleu.github.io/pytorch-meta/
MIT License

How to retain the original labels of test/train targets? #157

Open srvCodes opened 2 years ago

srvCodes commented 2 years ago

Hi,

I have been trying to retain the original labels of the train/test targets, for example the targets on lines 45 and 50 of the protonet training script. Could you please help?

By original labels, I mean the integer labels before the targets are remapped to the range [0, n_way-1] during few-shot training.
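To make the question concrete, here is a minimal, library-independent sketch of per-task relabeling that also keeps the inverse mapping back to the original labels. `relabel_task` is a hypothetical helper, not part of torchmeta. It assigns indices in order of first appearance, whereas torchmeta's own transform may assign them in a different order, so treat this only as an illustration of the idea.

```python
from typing import Dict, List, Sequence, Tuple


def relabel_task(labels: Sequence[int]) -> Tuple[List[int], Dict[int, int]]:
    """Hypothetical helper: map each distinct original label to 0..n_way-1.

    Indices are assigned in order of first appearance. Returns the remapped
    labels and the inverse mapping (new index -> original label).
    """
    forward: Dict[int, int] = {}
    for label in labels:
        if label not in forward:
            forward[label] = len(forward)
    remapped = [forward[label] for label in labels]
    inverse = {new: old for old, new in forward.items()}
    return remapped, inverse


# Usage: a 3-way task whose original class indices are 42, 7 and 13
remapped, inverse = relabel_task([42, 7, 42, 13, 7])
print(remapped)                          # [0, 1, 0, 2, 1]
print([inverse[r] for r in remapped])    # [42, 7, 42, 13, 7]
```

Keeping the inverse dictionary alongside the remapped labels is what lets you recover the original labels after the model has been trained on the [0, n_way-1] indices.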

vivektrivedy commented 2 years ago

Any update on this? @srvCodes

vivektrivedy commented 2 years ago

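For the existing datasets, you can pass `target_transform=None` to the helper function to access the original labels. By default the helpers (e.g. `cifar_fs`) apply a `Categorical` target transform, and that transform is what remaps the labels to [0, n_way-1]. What the un-remapped targets look like depends on the dataset class.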

```python
from torchmeta.datasets.helpers import cifar_fs
from torchmeta.utils.data import BatchMetaDataLoader

# target_transform=None disables the default Categorical remapping to [0, n_way-1]
cfs = cifar_fs('data', shots=5, ways=5, test_shots=5, meta_train=True,
               download=True, target_transform=None)
dataloader = BatchMetaDataLoader(cfs, batch_size=1, shuffle=False)
```

To check:

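```python
sample = next(iter(dataloader))
# sample['train'] is an (inputs, targets) pair, so [1] selects the support-set targets
original_labels = sample['train'][1][0][1]
```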
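With `target_transform=None`, the prototypical loss typically still expects labels in [0, n_way-1], so you may need to remap each episode yourself while keeping the originals. The sketch below is one possible approach in plain Python, assuming the raw targets of a single task are available as flat sequences of sortable labels (e.g. integers or class-name strings). `relabel_episode` is a hypothetical helper, not part of torchmeta.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def relabel_episode(train_targets: Sequence[T],
                    test_targets: Sequence[T]) -> Tuple[List[int], List[int], List[T]]:
    """Hypothetical helper: remap one episode's targets to 0..n_way-1.

    A single mapping is built from the support (train) targets and reused for
    the query (test) targets, so both splits agree on the class indices.
    Returns (train_labels, test_labels, classes), where classes[i] is the
    original label for index i.
    """
    classes = sorted(set(train_targets))
    to_index = {c: i for i, c in enumerate(classes)}
    train_labels = [to_index[c] for c in train_targets]
    # Raises KeyError if a query class does not appear in the support set
    test_labels = [to_index[c] for c in test_targets]
    return train_labels, test_labels, classes


# Usage: a 2-way episode with original class indices 58 and 3
train_labels, test_labels, classes = relabel_episode([58, 3, 58, 3], [3, 58])
print(train_labels)                       # [1, 0, 1, 0]
print(test_labels)                        # [0, 1]
print([classes[i] for i in test_labels])  # [3, 58]
```

Building the mapping once per episode from the support set, and reusing it for the query set, matters: remapping the two splits independently could assign different indices to the same class.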