Closed joshuasv closed 1 year ago
Hello @joshuasv,
The fix is for the dataset to return a `torch.Tensor`. You also need your images to have 3 dimensions. This should work:
import torch
import numpy as np
import learn2learn as l2l
from torch.utils.data import dataset
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels
from learn2learn.vision.transforms import RandomClassRotation
class TestDataset(dataset.Dataset):
def __init__(self, mode="train"):
super().__init__()
def __len__(self):
return 42
def __getitem__(self, idx):
img = torch.zeros(1, 28, 28)
img[:28//2] = 1
label = idx // 10
return img, label
d = TestDataset()
d = l2l.data.MetaDataset(d)
transforms = [
NWays(d, 2), # Samples N random classes per task (here, N = 5)
# Samples K samples per class from the above N classes (here, K = 1)
KShots(d, 1),
LoadData(d), # Loads a sample from the dataset
RemapLabels(d), # Remaps labels starting from zero
# Re-orders samples s.t. they are sorted in consecutive order
ConsecutiveLabels(d),
# Randomly rotate sample over x degrees (only for vision tasks)
RandomClassRotation(d, [0, 90, 180, 270])
]
taskset = l2l.data.TaskDataset(d, transforms)
X, y = taskset.sample()
print(X.shape)
Description
The `vision.transforms.RandomClassRotation` transform does not work as expected. The `torch.utils.data.dataset.Dataset` returns from `__getitem__` a `numpy.ndarray` and an `int`.
To Reproduce