huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools
https://huggingface.co/docs/datasets
Apache License 2.0

Disable warning when using with_format format on tensors #7088

Open Haislich opened 3 months ago


Feature request

If we write this code:

"""Get data and define datasets."""

from enum import StrEnum
from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision import transforms

class Split(StrEnum):
    """Describes what type of split to use in the dataloader"""

    TRAIN = "train"
    TEST = "test"
    VAL = "validation"

class ImageNetDataLoader(DataLoader):
    """Create an ImageNetDataloader"""

    _preprocess_transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ]
    )

    def __init__(self, batch_size: int = 4, split: Split = Split.TRAIN):
        dataset = (
            load_dataset(
                "imagenet-1k",
                split=split,
                trust_remote_code=True,
                streaming=True,
            )
            .with_format("torch")
            .map(self._preprocess)
        )

        super().__init__(dataset=dataset, batch_size=batch_size)

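    def _preprocess(self, data):
        # With the "torch" format, images are expected as (C, H, W) tensors;
        # grayscale images have a single channel, so replicate it to 3.
        if data["image"].shape[0] < 3:
            data["image"] = data["image"].repeat(3, 1, 1)
        data["image"] = self._preprocess_transform(data["image"].float())
        return data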

if __name__ == "__main__":

    dataloader = ImageNetDataLoader(batch_size=2)
    for batch in dataloader:
        print(batch["image"])
        break

This triggers the following UserWarning:

datasets\formatting\torch_formatter.py:85: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
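Until the formatter itself changes, one user-side workaround (a suggestion, not part of the original report) is to silence only this specific message with Python's standard warnings module. The helper name suppress_tensor_copy_warning is hypothetical; the filter matches the start of the warning text shown above, so other warnings still come through.

```python
import contextlib
import warnings

import torch


@contextlib.contextmanager
def suppress_tensor_copy_warning():
    """Hypothetical helper: ignore only the copy-construct UserWarning inside this block."""
    # catch_warnings restores the previous filter state on exit.
    with warnings.catch_warnings():
        # `message` is a regex matched against the start of the warning text.
        warnings.filterwarnings(
            "ignore",
            message="To copy construct from a tensor",
            category=UserWarning,
        )
        yield


if __name__ == "__main__":
    source = torch.arange(3)
    with suppress_tensor_copy_warning():
        copy = torch.tensor(source)  # would normally emit the UserWarning
    print(copy)
```

Assuming the default single-process loading (num_workers=0) used in the reproduction, wrapping the for batch in dataloader: loop in with suppress_tensor_copy_warning(): should hide the message. Note that catch_warnings is not thread-safe.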

Motivation

This happens because of the way TorchFormatter._tensorize returns the formatted tensor. The function handles values of different types; according to some tests, the possible value types are int, numpy.ndarray and torch.Tensor. The warning is triggered when the value is already a torch.Tensor, because copy-constructing it with torch.tensor(value) is not the approach PyTorch recommends (the warning suggests value.clone().detach() instead).
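A quick way to check this explanation, independent of datasets, is to call torch.tensor on each of the three value types listed above and record the warnings it emits. The helper copy_warnings is illustrative only and is not part of datasets or PyTorch.

```python
import warnings

import numpy as np
import torch


def copy_warnings(value):
    """Call torch.tensor(value) and return any copy-construct warnings it emitted."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning, not just the first
        torch.tensor(value)
    return [str(w.message) for w in caught if "To copy construct" in str(w.message)]


if __name__ == "__main__":
    # The three value types mentioned above; only the tensor input should warn.
    for value in (3, np.arange(3), torch.arange(3)):
        print(type(value).__name__, len(copy_warnings(value)))
```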

Your contribution

A solution that I found to work is to change the current implementation:

return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})

To:

if isinstance(value, torch.Tensor):
    # Copy the tensor the way the warning recommends.
    tensor = value.clone().detach()
    if self.torch_tensor_kwargs.get("requires_grad", False):
        tensor.requires_grad_()
    return tensor
else:
    # Non-tensor inputs (e.g. int, numpy.ndarray) keep the existing behavior.
    return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
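For experimenting outside the library, the proposed branch can be lifted into a stand-alone function. This is a sketch under assumptions: the name tensorize and its parameters are hypothetical stand-ins for TorchFormatter._tensorize, self.torch_tensor_kwargs and the local default_dtype. Unlike the proposal above, it also forwards any dtype or device entries from the kwargs via Tensor.to(), since clone().detach() on its own would ignore them.

```python
import torch


def tensorize(value, torch_tensor_kwargs=None, default_dtype=None):
    """Hypothetical stand-alone version of the proposed _tensorize branch."""
    kwargs = {**(default_dtype or {}), **(torch_tensor_kwargs or {})}
    if isinstance(value, torch.Tensor):
        # Recommended copy: no UserWarning, and the result is detached from value's graph.
        tensor = value.clone().detach()
        # Extension beyond the proposal: still honour dtype/device overrides.
        to_kwargs = {k: kwargs[k] for k in ("dtype", "device") if k in kwargs}
        if to_kwargs:
            tensor = tensor.to(**to_kwargs)
        if kwargs.get("requires_grad", False):
            tensor.requires_grad_()
        return tensor
    # Non-tensor inputs go through torch.tensor as before.
    return torch.tensor(value, **kwargs)
```

For example, tensorize(torch.arange(3), {"dtype": torch.float32}) returns a float32 copy without going through torch.tensor, so the copy-construct warning is not emitted.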