"""Get data and define datasets."""
from enum import StrEnum
from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision import transforms
class Split(StrEnum):
"""Describes what type of split to use in the dataloader"""
TRAIN = "train"
TEST = "test"
VAL = "validation"
class ImageNetDataLoader(DataLoader):
"""Create an ImageNetDataloader"""
_preprocess_transform = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
]
)
def __init__(self, batch_size: int = 4, split: Split = Split.TRAIN):
dataset = (
load_dataset(
"imagenet-1k",
split=split,
trust_remote_code=True,
streaming=True,
)
.with_format("torch")
.map(self._preprocess)
)
super().__init__(dataset=dataset, batch_size=batch_size)
def _preprocess(self, data):
if data["image"].shape[0] < 3:
data["image"] = data["image"].repeat(3, 1, 1)
data["image"] = self._preprocess_transform(data["image"].float())
return data
if __name__ == "__main__":
dataloader = ImageNetDataLoader(batch_size=2)
for batch in dataloader:
print(batch["image"])
break
This will trigger a user warning:

```
datasets\formatting\torch_formatter.py:85: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
```

Motivation

This happens because of the way the formatted tensor is returned in `TorchFormatter._tensorize`. This function handles values of different types; according to some tests, the possible value types are `int`, `numpy.ndarray` and `torch.Tensor`. In particular, the warning is triggered when the value type is `torch.Tensor`, because passing an existing tensor to `torch.tensor()` is not the suggested PyTorch way of doing it.
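For context, the warning can be reproduced with plain PyTorch alone whenever `torch.tensor()` is called on a value that is already a tensor; the snippet below is a standalone illustration with made-up variable names, not code from `datasets`:

```python
import warnings

import torch

value = torch.ones(3)  # already a torch.Tensor, like the values produced by .with_format("torch")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    copied = torch.tensor(value)  # copy-constructing from a tensor raises the UserWarning

if caught:
    print(caught[0].category.__name__, caught[0].message)

# The recommended, warning-free alternative from the warning message:
copied = value.clone().detach()
```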
Your contribution

A solution that I found to work is to change the current way of doing it, i.e. the `return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})` line shown in the warning above, so that a value that is already a `torch.Tensor` is not re-wrapped with `torch.tensor()`.
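A minimal sketch of what such a change could look like, following the recommendation in the warning itself; the standalone function name and signature below are illustrative and not the actual code of `TorchFormatter._tensorize`:

```python
import torch


def tensorize_value(value, default_dtype: dict, torch_tensor_kwargs: dict):
    """Illustrative stand-in for the final return statement of TorchFormatter._tensorize."""
    if isinstance(value, torch.Tensor):
        # Values that are already tensors follow the warning's suggestion
        # instead of being copy-constructed with torch.tensor().
        return value.clone().detach()
    # Other value types (e.g. int, numpy.ndarray) keep the current behaviour.
    return torch.tensor(value, **{**default_dtype, **torch_tensor_kwargs})


print(tensorize_value(torch.ones(2), {}, {}))                # tensor path: no UserWarning
print(tensorize_value([1, 2], {"dtype": torch.int64}, {}))   # non-tensor path: unchanged
```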