learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License
2.61k stars 350 forks source link

vision.transforms.RandomClassRotation not working #378

Closed joshuasv closed 1 year ago

joshuasv commented 1 year ago

Description

`vision.transforms.RandomClassRotation` does not work as expected. The `torch.utils.data.dataset.Dataset` below returns a `numpy.ndarray` and an `int` from `__getitem__`.

To Reproduce

import numpy as np
import learn2learn as l2l
from torch.utils.data import dataset
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels
from learn2learn.vision.transforms import RandomClassRotation

class TestDataset(dataset.Dataset):
    """Minimal dataset used to reproduce the RandomClassRotation failure.

    It always holds 42 samples. Each sample is the same 28x28 float NumPy
    array, with the first 14 rows set to 1 and the remaining rows set to 0,
    paired with the integer label 0. Because the image is a NumPy array and
    not a ``torch.Tensor``, the rotation transform raises
    ``TypeError: Unexpected type <class 'numpy.ndarray'>``.
    """

    def __init__(self, mode="train"):
        # `mode` is accepted but never used.
        super().__init__()

    def __len__(self):
        num_samples = 42
        return num_samples

    def __getitem__(self, idx):
        # `idx` is ignored: every sample is identical.
        rows, cols = 28, 28
        upper = np.ones((rows // 2, cols))
        lower = np.zeros((rows - rows // 2, cols))
        return np.vstack((upper, lower)), 0

d = TestDataset()
d = l2l.data.MetaDataset(d)

transforms = [
    NWays(d, 2),  # Samples N random classes per task (here, N = 2)
    KShots(d, 1),  # Samples K samples per class from the above N classes (here, K = 1)
    LoadData(d),  # Loads a sample from the dataset
    RemapLabels(d),  # Remaps labels starting from zero
    ConsecutiveLabels(d),  # Re-orders samples s.t. they are sorted in consecutive order
    RandomClassRotation(d, [0, 90, 180, 270])  # Randomly rotate sample over x degrees (only for vision tasks)
]

# The task dataset must be built from the transforms before it can be sampled.
# A finite num_tasks keeps the `for task in taskset` loop below from running
# forever (the default of -1 presumably means "unbounded"; confirm in the docs).
taskset = l2l.data.TaskDataset(d, transforms, num_tasks=10)

X, y = taskset.sample()
# or, you can also sample this way:
for task in taskset:
    X, y = task
print(X.shape)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [114], in <cell line: 1>()
----> 1 X, y = taskset.sample()
      2 # or, you can also sample this way:
      3 for task in taskset:

File learn2learn/data/task_dataset.pyx:158, in learn2learn.data.task_dataset.CythonTaskDataset.sample()

File learn2learn/data/task_dataset.pyx:173, in learn2learn.data.task_dataset.CythonTaskDataset.__getitem__()

File learn2learn/data/task_dataset.pyx:142, in learn2learn.data.task_dataset.CythonTaskDataset.get_task()

File ~/.local/lib/python3.8/site-packages/learn2learn/vision/transforms.py:57, in RandomClassRotation.__call__.<locals>.<lambda>(x)
     49             rotations[c] = transforms.Compose(
     50                 [
     51                     transforms.ToPILImage(),
   (...)
     54                 ]
     55             )
     56     rotation = rotations[c]
---> 57     data_description.transforms.append(lambda x: (rotation(x[0]), x[1]))
     58 return task_description

File /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1186, in Module._call_impl(self, *input, **kwargs)
   1182 # If we don't have any hooks, we want to skip the rest of the logic in
   1183 # this function, and just call forward.
   1184 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1185         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1186     return forward_call(*input, **kwargs)
   1187 # Do not call functions when jit is used
   1188 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.8/site-packages/torchvision/transforms/transforms.py:1357, in RandomRotation.forward(self, img)
   1349 """
   1350 Args:
   1351     img (PIL Image or Tensor): Image to be rotated.
   (...)
   1354     PIL Image or Tensor: Rotated image.
   1355 """
   1356 fill = self.fill
-> 1357 channels, _, _ = F.get_dimensions(img)
   1358 if isinstance(img, Tensor):
   1359     if isinstance(fill, (int, float)):

File /opt/conda/lib/python3.8/site-packages/torchvision/transforms/functional.py:75, in get_dimensions(img)
     72 if isinstance(img, torch.Tensor):
     73     return F_t.get_dimensions(img)
---> 75 return F_pil.get_dimensions(img)

File /opt/conda/lib/python3.8/site-packages/torchvision/transforms/functional_pil.py:33, in get_dimensions(img)
     31     width, height = img.size
     32     return [channels, height, width]
---> 33 raise TypeError(f"Unexpected type {type(img)}")

TypeError: Unexpected type <class 'numpy.ndarray'>
seba-1511 commented 1 year ago

Hello @joshuasv,

The fix is for the dataset to return a `torch.Tensor` instead of a NumPy array. The images also need 3 dimensions (channels, height, width). This should work:

import torch
import numpy as np
import learn2learn as l2l
from torch.utils.data import dataset
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels
from learn2learn.vision.transforms import RandomClassRotation

class TestDataset(dataset.Dataset):
    """Toy dataset of 42 single-channel 28x28 image tensors.

    In each image, the top 14 rows are 1 and the bottom 14 rows are 0.
    Labels are ``idx // 10``, so every 10 consecutive indices share a class
    and the labels run from 0 to 4.
    """

    def __init__(self, mode="train"):
        # `mode` is accepted but never used.
        super().__init__()

    def __len__(self):
        return 42

    def __getitem__(self, idx):
        # Return a torch.Tensor: torchvision's rotation accepts a PIL Image
        # or a Tensor, but not a NumPy array.
        img = torch.zeros(1, 28, 28)
        # Index the row axis (dim 1). `img[:28 // 2]` would slice the size-1
        # channel axis and fill the entire image with ones.
        img[:, :28 // 2] = 1
        label = idx // 10

        return img, label

d = TestDataset()
d = l2l.data.MetaDataset(d)

transforms = [
    NWays(d, 2),  # Samples N random classes per task (here, N = 2)
    # Samples K samples per class from the above N classes (here, K = 1)
    KShots(d, 1),
    LoadData(d),  # Loads a sample from the dataset
    RemapLabels(d),  # Remaps labels starting from zero
    # Re-orders samples s.t. they are sorted in consecutive order
    ConsecutiveLabels(d),
    # Randomly rotate sample over x degrees (only for vision tasks)
    RandomClassRotation(d, [0, 90, 180, 270])
]

taskset = l2l.data.TaskDataset(d, transforms)

X, y = taskset.sample()
print(X.shape)