Open clemsgrs opened 9 months ago
I ran into a similar issue, and it really comes down to your device's computational power. One solution I initially used when running out of memory was to write to the hard drive and read back from it, using it as a kind of pseudo-RAM. This is incredibly slow, though. If you are working with a dataset that large, I highly recommend offloading your processing onto a supercomputing cluster if you have access to one. If you are simply trying to create and load the dataset, I would recommend using Hugging Face to store and load the dataset in pieces: do not load it all at once, but rather in batches. Please let me know if anything I said is unclear or not applicable to your situation.
@clemsgrs Did you manage to find a solution to your issues? I am facing similar problems.
Hi, I'm using DINOv2 to pretrain a ViT on a dataset significantly larger than ImageNet22k (between 100M and 1B JPEG images). I stuck with the ImageNet22k dataset class for handling and loading data, i.e. a combination of tarball files for storing images and a single npy file for metadata (start and end offsets, plus information on which tarball file contains a given image). The code snippet is below.
Unfortunately, I am facing very slow data loading times:
1) Large tarball files: some of the tarballs I work with contain as many as 6M images. I suspect this increases RAM usage, which could explain the slow data loading times -- or even the out-of-memory errors -- I face.
2) To mitigate this issue, I split the large tarballs into smaller ones (of 1 GB each). Although this offers some relief by reducing the memory footprint during data loading, it doesn't scale well with the batch size: the bigger the batch size, the more tarball files need to be opened/closed concurrently, which seems to add significant overhead and slows down data loading.
I've tried looking into alternative tools (WebDataset, TorchData), but without success. I am therefore reaching out for any advice or alternative strategies for handling large-scale vision datasets. Thank you!
Dataset code
```python import numpy as np from io import BytesIO from typing import Any from PIL import Image from pathlib import Path from mmap import ACCESS_READ, mmap from typing import Any, Callable, Optional, Tuple from torchvision.datasets import VisionDataset from functools import lru_cache class Decoder: def decode(self) -> Any: raise NotImplementedError class ImageDataDecoder(Decoder): def __init__(self, image_data: bytes) -> None: self._image_data = image_data def decode(self) -> Image: f = BytesIO(self._image_data) return Image.open(f).convert(mode="RGB") class TargetDecoder(Decoder): def __init__(self, target: Any): self._target = target def decode(self) -> Any: return self._target _DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors def _get_tarball_path(dataset_name: str) -> str: return f"{dataset_name}.tar" def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int): @lru_cache(maxsize=mmap_cache_size) def _mmap_tarball(dataset_name: str) -> mmap: tarball_path = _get_tarball_path(dataset_name) tarball_full_path = Path(tarballs_root, tarball_path) with open(tarball_full_path) as f: return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ) return _mmap_tarball class FoundationDataset(VisionDataset): def __init__( self, *, root: str, transforms: Optional[Callable] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE, ) -> None: super().__init__(root, transforms, transform, target_transform) self._get_entries() self._get_dataset_names() self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size) @property def _tarballs_root(self) -> str: return self.root @property def _entries_name(self) -> str: return "pretrain_entries.npy" def _get_entries(self) -> np.ndarray: self._entries = self._load_entries(self._entries_name) def _load_entries(self, _entries_name: str) -> np.ndarray: entries_path = Path(self.root, _entries_name) return 
np.load(entries_path, mmap_mode="r") def _get_filepaths_dict(self, dataset_name: str): return self._load_filepaths_dict(dataset_name) def _load_filepaths_dict(self, dataset_name: str): filepaths_dict_path = Path(self.root, f"{dataset_name}_file_indices.npy") return np.load(filepaths_dict_path, allow_pickle=True).item() def _get_dataset_names(self) -> dict: self._dataset_names = self._load_dataset_names() def _load_dataset_names(self) -> dict: dataset_dict_path = Path(self.root, "dataset_indices.npy") return np.load(dataset_dict_path, allow_pickle=True).item() def get_image_data(self, index: int) -> bytes: entry = self._entries[index] file_idx, start_offset, end_offset, dataset_idx = ( entry[1], entry[2], entry[3], entry[4], ) dataset_name = self._dataset_names[dataset_idx] filepaths_dict = self._get_filepaths_dict(dataset_name) filepath = filepaths_dict[file_idx] class_mmap = self._mmap_tarball(dataset_name) data = class_mmap[start_offset:end_offset] return data, Path(filepath) def get_target(self, index: int) -> Any: return int(self._entries[index][0]) def get_targets(self) -> np.ndarray: return self._entries[:, 0] def __getitem__(self, index: int) -> Tuple[Any, Any]: try: image_data, _ = self.get_image_data(index) image = ImageDataDecoder(image_data).decode() except Exception as e: raise RuntimeError(f"can not read image for sample {index} ({e})") from e target = self.get_target(index) target = TargetDecoder(target).decode() if self.transforms is not None: image, target = self.transforms(image, target) return image, target def __len__(self) -> int: return len(self._entries) ```