Closed juliendenize closed 4 months ago
### Context

Since the release of PyTorch 2.3.0, using `pin_memory = True` in the dataloaders removes the typing of Torchaug tensors, which is impractical.

### Reproduce

```python
import torch
from torch.utils.data import DataLoader, Dataset

from torchaug.data.dataloader import default_collate
from torchaug.ta_tensors import Image


class MyDataset(Dataset):
    def __init__(self):
        super().__init__()

    def __getitem__(self, index):
        return Image(torch.rand(1, 3, 224, 224))

    def __len__(self):
        return 10


dataset = MyDataset()

dataloader = DataLoader(dataset, batch_size=2, collate_fn=default_collate, pin_memory=True)
batch = next(iter(dataloader))
print(type(batch))
# >>> torch.Tensor

dataloader = DataLoader(dataset, batch_size=2, collate_fn=default_collate, pin_memory=False)
batch = next(iter(dataloader))
print(type(batch))
# >>> torchaug.ta_tensors._batch_images.BatchImages
```
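The underlying failure mode can be shown with plain `torch` and no CUDA: whenever an operation hands back a bare `torch.Tensor` instead of the subclass (as the pinned batches above do), the typing is lost, and `Tensor.as_subclass` restores it. This is a minimal sketch with a stand-in `MyImage` class, not Torchaug's actual `Image`:

```python
import torch


class MyImage(torch.Tensor):
    """Stand-in for a Torchaug tensor subclass (hypothetical, for illustration)."""
    pass


img = torch.rand(3, 4).as_subclass(MyImage)
print(type(img).__name__)      # MyImage

# An operation that returns a plain torch.Tensor drops the subclass,
# which is what the pin_memory=True batches above exhibit.
plain = img.as_subclass(torch.Tensor)
print(type(plain).__name__)    # Tensor

# The typing can be restored by re-wrapping the plain tensor.
restored = plain.as_subclass(MyImage)
print(type(restored).__name__)  # MyImage
```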
### Suggestion for fix

- Disable `pin_memory`: short-term solution, which can cause slowdowns
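As an alternative to disabling `pin_memory`, a possible user-side workaround (an assumption for illustration, not an official Torchaug API) is to keep `pin_memory=True` and re-wrap each batch back into the expected type inside the training loop:

```python
import torch


class BatchImages(torch.Tensor):
    """Stand-in for torchaug.ta_tensors.BatchImages (hypothetical)."""
    pass


def rewrap(batch, ta_cls):
    """Restore the tensor subclass that pinning stripped; no-op otherwise."""
    if isinstance(batch, torch.Tensor) and not isinstance(batch, ta_cls):
        return batch.as_subclass(ta_cls)
    return batch


# In a training loop this would be: batch = rewrap(batch, BatchImages)
pinned_like = torch.rand(2, 3, 8, 8)  # plain tensor, as yielded with pin_memory=True
batch = rewrap(pinned_like, BatchImages)
print(type(batch).__name__)  # BatchImages
```

Re-wrapping is cheap (`as_subclass` does not copy data), so the pinned-memory transfer speed is kept while the typing is restored.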