Closed: siemdejong closed this 3 weeks ago.
Thanks a lot for your PR!
I think we shouldn't modify `sys.modules`, as this also impacts code outside of lightly. Users might get really surprised if `torchvision.transforms` is suddenly v2 instead of v1 just because they imported lightly. I would prefer the approach where we define a new `torchvision_transforms` variable. I would even put it in a new file in `lightly/transforms/torchvision_transforms.py`.
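A minimal, self-contained sketch of why patching `sys.modules` leaks outside the library. It uses two hypothetical stand-in modules (`demo_transforms` and `demo_transforms_v2`) instead of torchvision, and the two `import_library_*` functions are illustrative only, not lightly's code.

```python
import importlib
import sys
import types

# Hypothetical stand-ins for torchvision.transforms (v1) and
# torchvision.transforms.v2; torchvision is not needed to run this.
transforms_v1 = types.ModuleType("demo_transforms")
transforms_v1.VERSION = "v1"
transforms_v2 = types.ModuleType("demo_transforms_v2")
transforms_v2.VERSION = "v2"
sys.modules["demo_transforms"] = transforms_v1


def import_library_with_own_variable() -> types.ModuleType:
    # Preferred approach: the library keeps v2 in its own variable,
    # so the interpreter-wide module registry is left untouched.
    torchvision_transforms = transforms_v2
    return torchvision_transforms


def import_library_patching_sys_modules() -> None:
    # Approach under discussion: the library rebinds the v1 module
    # name to v2 for every importer in the process.
    sys.modules["demo_transforms"] = transforms_v2


lib_transforms = import_library_with_own_variable()
# Equivalent to a later `import demo_transforms` in user code.
user_before = importlib.import_module("demo_transforms")
print(lib_transforms.VERSION, user_before.VERSION)  # v2 v1

import_library_patching_sys_modules()
user_after = importlib.import_module("demo_transforms")
print(user_after.VERSION)  # v2
```

With the separate variable, only code that explicitly uses the library's `torchvision_transforms` sees v2; with the `sys.modules` patch, every later import of the v1 name silently returns v2.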
I also noticed that the v2 `ToTensor` transform is deprecated and shows a warning when used. We should replace it with a `ToTensor` transform that doesn't show the warning. Something like this should work:
```python
from typing import Callable

import torch
from torch import Tensor
from PIL.Image import Image as PILImage

try:
    from torchvision.transforms import v2 as torchvision_transforms

    _TRANSFORMS_V2 = True
except ImportError:
    from torchvision import transforms as torchvision_transforms

    _TRANSFORMS_V2 = False


def ToTensor() -> Callable[[PILImage], Tensor]:
    T = torchvision_transforms
    if _TRANSFORMS_V2:
        # v2.transforms.ToTensor is deprecated and will be removed in the future.
        # This is the new recommended way to convert a PIL Image to a tensor:
        return T.Compose([T.ToImage(), T.ToDtype(dtype=torch.float32, scale=True)])
    else:
        return T.ToTensor()
```
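For comparison, a pure-Python sketch of the contract both branches are expected to satisfy for 8-bit images: v1 `ToTensor` maps an H x W x C image with values in [0, 255] to a C x H x W float tensor with values in [0.0, 1.0], and `ToImage()` followed by `ToDtype(torch.float32, scale=True)` is the v2 way to get the same result. `to_tensor_like` is a hypothetical helper that works on nested lists; it is not a torchvision API and only illustrates the layout change and the division by 255.

```python
from typing import List

Pixel = List[int]  # one value per channel, each in 0..255


def to_tensor_like(image_hwc: List[List[Pixel]]) -> List[List[List[float]]]:
    """Hypothetical illustration of the ToTensor contract:
    HWC integers in [0, 255] -> CHW floats in [0.0, 1.0]."""
    height = len(image_hwc)
    width = len(image_hwc[0]) if height else 0
    channels = len(image_hwc[0][0]) if width else 0
    # Move the channel axis to the front and scale by 1/255.
    return [
        [[image_hwc[y][x][c] / 255.0 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


image = [[[0, 128, 255], [255, 0, 0]]]  # H=1, W=2, C=3 (RGB)
chw = to_tensor_like(image)
print(len(chw), len(chw[0]), len(chw[0][0]))  # 3 1 2
print(chw[2])  # [[1.0, 0.0]]
```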
Resolves #1547

It deviates a little from the method proposed by @guarin: importing lightly makes Python's import system treat `torchvision.transforms` as `torchvision.transforms.v2`. I chose this option because it only requires a change in `lightly.transforms.__init__.py`.