juliendenize / torchaug

Library to perform efficient vision data augmentations for CPU/GPU per-sample/batched data.
https://torchaug.readthedocs.io/en/latest/

Pin memory in dataloaders removes the typing of tensors #49

Closed juliendenize closed 4 months ago

juliendenize commented 4 months ago

Context

Since the release of PyTorch 2.3.0, using `pin_memory=True` in dataloaders removes the typing of Torchaug tensors, which is impractical.
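Losing the type matters because typed-tensor libraries dispatch on it. A minimal sketch of the problem (the `Image` subclass and `normalize` function below are illustrative stand-ins, not the actual Torchaug API):

```python
import torch

# Illustrative stand-in (assumption): a typed tensor subclass like
# Torchaug's Image carries its semantics in its Python type.
class Image(torch.Tensor):
    pass

def normalize(t: torch.Tensor) -> str:
    # Transforms pick their code path from the tensor's type, so a batch
    # demoted to plain torch.Tensor falls through to the generic branch.
    if isinstance(t, Image):
        return "image-specific path"
    return "generic tensor path"

typed = torch.rand(3, 4).as_subclass(Image)
print(normalize(typed))             # image-specific path
print(normalize(torch.rand(3, 4)))  # generic tensor path
```

Once `pin_memory` demotes a `BatchImages` to a plain `torch.Tensor`, every such type check downstream silently takes the wrong branch.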

Reproduce

```python
import torch
from torch.utils.data import Dataset, DataLoader
from torchaug.data.dataloader import default_collate
from torchaug.ta_tensors import Image

class MyDataset(Dataset):
    def __getitem__(self, index):
        return Image(torch.rand(1, 3, 224, 224))

    def __len__(self):
        return 10

dataset = MyDataset()

dataloader = DataLoader(dataset, batch_size=2, collate_fn=default_collate, pin_memory=True)
batch = next(iter(dataloader))
print(type(batch))
# <class 'torch.Tensor'>

dataloader = DataLoader(dataset, batch_size=2, collate_fn=default_collate, pin_memory=False)
batch = next(iter(dataloader))
print(type(batch))
# <class 'torchaug.ta_tensors._batch_images.BatchImages'>
```

Suggestions for a fix

  1. Until it is fixed, remove `pin_memory`: a short-term solution that can cause slowdowns.
  2. Partly rewrite the dataloader and expose it in Torchaug.
  3. Recast the tensor types after collation, which requires storing metadata for several Torchaug tensors (e.g. the number of masks per sample).
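Suggestion 3 could be sketched as follows. This is a hypothetical illustration, not the Torchaug implementation: the `BatchImages` stand-in and `recast` helper are assumptions, and the real fix would also need the per-sample metadata mentioned above.

```python
import torch

# Hypothetical stand-in (assumption): Torchaug's typed batch class is,
# at its core, a torch.Tensor subclass.
class BatchImages(torch.Tensor):
    pass

def recast(batch: torch.Tensor, original_type: type) -> torch.Tensor:
    # If pin_memory stripped the subclass, view the plain Tensor back as
    # the typed class without copying data.
    if type(batch) is torch.Tensor:
        return batch.as_subclass(original_type)
    return batch

plain = torch.rand(2, 3, 224, 224)   # pretend this came out of pin_memory
restored = recast(plain, BatchImages)
print(type(restored).__name__)       # BatchImages
```

`Tensor.as_subclass` restores the Python type in place, but it cannot recover metadata that lives outside the tensor itself (e.g. how many masks belong to each sample), which is why this option forces that metadata to be stored alongside the batch.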