lightly-ai / lightly

A Python library for self-supervised learning on images.
https://docs.lightly.ai/self-supervised-learning/
MIT License
2.83k stars, 246 forks

Torchvision transforms v2 #1555

Closed: siemdejong closed this 3 weeks ago

siemdejong commented 4 weeks ago

Resolves #1547

It deviates slightly from the method proposed by @guarin: lightly makes the importer think that torchvision.transforms is torchvision.transforms.v2. I chose this option because it only requires a change in lightly.transforms.__init__.py.

guarin commented 3 weeks ago

Thanks a lot for your PR!

I think we shouldn't modify sys.modules, as this also impacts code outside of lightly. Users might be quite surprised if torchvision.transforms is suddenly v2 instead of v1 just because they imported lightly. I would prefer the approach where we define a new torchvision_transforms variable. I would even put it in a new file, lightly/transforms/torchvision_transforms.py.
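To illustrate the concern about sys.modules, here is a minimal sketch that uses two stand-in modules rather than torchvision itself. The names fakelib, fakelib.v2, user_code_version and library_transforms are hypothetical and exist only for this demonstration. It shows that rebinding an entry in sys.modules changes what every later import in the process receives, whereas binding the v2 module to a private name leaves other code untouched.

```python
import importlib
import sys
import types

# Two stand-in modules (hypothetical names, not real packages).
v1 = types.ModuleType("fakelib")
v1.VERSION = "v1"
v2 = types.ModuleType("fakelib.v2")
v2.VERSION = "v2"
sys.modules["fakelib"] = v1
sys.modules["fakelib.v2"] = v2


def user_code_version() -> str:
    """Simulate unrelated user code running `import fakelib`."""
    return importlib.import_module("fakelib").VERSION


before = user_code_version()

# Global aliasing: every importer in the process now sees v2.
sys.modules["fakelib"] = sys.modules["fakelib.v2"]
after_alias = user_code_version()

# Non-invasive alternative: restore the entry and keep v2 under a
# private name that only the library itself uses.
sys.modules["fakelib"] = v1
library_transforms = v2
after_local = user_code_version()

print(before, after_alias, after_local)  # prints: v1 v2 v1
```

With the private-name approach the library still gets the v2 behaviour through library_transforms, while user code that imports the original module keeps getting v1.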

I also noticed that the v2 ToTensor transform is deprecated and shows a warning when used. We should replace it with a ToTensor transform that doesn't show the warning. Something like this should work:

from typing import Callable
import torch
from torch import Tensor
from PIL.Image import Image as PILImage

try:
    from torchvision.transforms import v2 as torchvision_transforms
    _TRANSFORMS_V2 = True
except ImportError:
    from torchvision import transforms as torchvision_transforms
    _TRANSFORMS_V2 = False

def ToTensor() -> Callable[[PILImage], Tensor]:
    T = torchvision_transforms
    if _TRANSFORMS_V2:
        # transforms.v2.ToTensor is deprecated and will be removed in the future.
        # This is the new recommended way to convert a PIL Image to a tensor:
        return T.Compose([T.ToImage(), T.ToDtype(dtype=torch.float32, scale=True)])
    else:
        return T.ToTensor()
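
A possible variation on the helper above, offered as a sketch rather than as the approach the PR settled on: instead of relying only on whether the v2 import succeeds, it checks that the chosen namespace actually provides ToImage and ToDtype. This assumes that some torchvision releases may ship a v2 module without those two transforms, which is worth verifying against the versions lightly supports. The names build_to_tensor and demo are hypothetical. The namespace and dtype are passed in as arguments so the selection logic can be exercised without torchvision installed.

```python
from typing import Any, Callable


def build_to_tensor(T: Any, float_dtype: Any) -> Callable[[Any], Any]:
    """Return a PIL-image-to-float-tensor transform built from namespace T.

    T is expected to be torchvision.transforms or torchvision.transforms.v2,
    and float_dtype is expected to be torch.float32.
    """
    # Feature-detect the replacement transforms instead of trusting the
    # import alone (assumption: not every v2 release provides them).
    if hasattr(T, "ToImage") and hasattr(T, "ToDtype"):
        return T.Compose([T.ToImage(), T.ToDtype(float_dtype, scale=True)])
    # Fallback: the classic ToTensor transform.
    return T.ToTensor()


def demo() -> None:
    """Needs torch, torchvision and Pillow; not called automatically."""
    import torch
    from PIL import Image

    try:
        from torchvision.transforms import v2 as T
    except ImportError:
        from torchvision import transforms as T

    to_tensor = build_to_tensor(T, torch.float32)
    out = to_tensor(Image.new("RGB", (4, 3)))
    print(type(out).__name__, out.dtype, tuple(out.shape))
```

Calling demo() in an environment with torchvision installed converts a small blank image and prints the resulting type, dtype and shape.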