mosaicml / streaming

A Data Streaming Library for Efficient Neural Network Training
https://streaming.docs.mosaicml.com
Apache License 2.0

Out of Memory when using Streaming Dataloader #652

Open VikaasVarma opened 4 months ago

VikaasVarma commented 4 months ago

Environment

To reproduce

Steps to reproduce the behavior:

When using the StreamingDataLoader (or the vanilla PyTorch DataLoader) with num_workers>0, the worker processes slowly take more and more memory until the CPU RAM is filled.

Expected behavior

The dataloader should be able to provide samples indefinitely without using a significant portion of available RAM.

Additional context

Below is the dataset and dataloader implementation. Each sample is roughly 10 MB. With 16 workers, a prefetch factor of 4, and a batch size of 32, total memory usage should be at most 20 GB. The dataset is made up of around 1.3 million shards.
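For reference, the 20 GB figure can be reproduced with the short calculation below. It is only a rough sketch: it assumes every worker holds its full prefetch queue at once and that each sample really occupies about 10 MB once decoded, neither of which is confirmed in the thread.

```python
# Back-of-the-envelope bound on the samples held in memory by DataLoader prefetching.
# The numbers are the ones quoted in the issue; the sample size is approximate.
num_workers = 16
prefetch_factor = 4  # batches prefetched per worker
batch_size = 32
sample_mb = 10       # approximate size of one sample, in MB

# Each worker can hold up to `prefetch_factor` batches of `batch_size` samples.
samples_in_flight = num_workers * prefetch_factor * batch_size
total_gb = samples_in_flight * sample_mb / 1024

print(f"{samples_in_flight} samples in flight, about {total_gb:.0f} GB")
```

Anything well beyond this bound points at memory that is not sample data, which is what the rest of the thread tries to track down.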

A similar problem seems to be documented in an issue and a blog post. I have recreated the graphs found in the blog post below.

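import torch
from streaming import StreamingDataLoader, StreamingDataset
from torchvision.transforms import v2 as T
from tqdm import tqdm

# Assumed alias: Sample is not defined in the original snippet.
Sample = tuple[torch.Tensor, torch.Tensor]

class ImageTokenDataset(StreamingDataset):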
    def __init__(
        self,
        remote: str,
        batch_size: int,
        shuffle: bool = False,
        local: str | None = None,
        split: str | None = None,
        transforms: T.Compose = T.Compose([T.ToImage(), T.ToDtype(torch.float32)]),
        input_key: str = "jpg",
        cond_key: str = "cond",
        cond_dropout_rate: float = 0.5,
        predownload: int | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            local=local,
            remote=remote,
            shuffle=shuffle,
            batch_size=batch_size,
            split=split,
            predownload=predownload,
            **kwargs,
        )

        self.batch_size = batch_size
        self.transforms = transforms
        self.input_key = input_key
        self.cond_key = cond_key
        self.cond_dropout_rate = cond_dropout_rate

    def __getitem__(self, at: int) -> Sample:
        obj = super().__getitem__(at)

        _input = self.transforms(obj[self.input_key])
        cond = torch.tensor(obj[self.cond_key])

        if torch.rand(1) < self.cond_dropout_rate:
            cond = torch.zeros_like(cond)

        return _input, cond

    def to_dataloader(
        self,
        num_workers: int = 8,
        prefetch_factor: int | None = None,
        persistent_workers: bool = True,
        pin_memory: bool = True,
        drop_last: bool = True,
        batch_size: int | None = None,
    ):
        return StreamingDataLoader(
            self,
            batch_size=batch_size or self.batch_size,
            drop_last=drop_last,
            prefetch_factor=prefetch_factor,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
            pin_memory=pin_memory,
        )

if __name__ == "__main__":
    dataset = ImageTokenDataset(
        remote=remote_path,
        batch_size=32,
        local="/tmp/dataset/train",
        split="train",
        input_key="jpg",
        cond_key="t5",
        cond_dropout_rate=0.5,
    )
    dataloader = dataset.to_dataloader(
        num_workers=16, persistent_workers=True, pin_memory=False, prefetch_factor=4
    )

    for _ in tqdm(dataloader):
        pass

image

miguelalba96 commented 4 months ago

I'm experiencing similar issues when loading image/text pairs (local). RAM usage increases non-stop (the GPU is stable). I managed to partially "solve" it by decreasing the number of workers in the data loader (each node has at most 16 vCPUs and 2 GPUs), so I set num_workers=8 and disabled persistent_workers. I guess that if I leave persistent_workers=True, the training will eventually crash.

(training happens after downloading the data locally to the nodes)

image

Maybe Streaming is not cleaning up some state, leading to that memory accumulation?

snarayan21 commented 4 months ago

Hey y’all, thanks for bringing this issue to our attention. We’re looking into this and will get back to you soon.

snarayan21 commented 4 months ago

I skimmed through the blog and the PyTorch issue. Is this an issue particular to Streaming, or is it on the PyTorch side? StreamingDataLoader is a simple (stateful) subclass of PyTorch’s DataLoader. Does this also happen with other Datasets? @VikaasVarma @miguelalba96

snarayan21 commented 4 months ago

So Streaming is designed for fast random sample access from shards that live on disk. Samples, outside of dataloader prefetching, are never kept in memory. We do use RAM for other things, though, like sample partitioning and shuffling, but this happens at the start of training. So I'm inclined to think that this is a PyTorch DataLoader issue, given the links you sent as well.

To track memory usage, maybe you could call gc.get_referents() in a loop to track mem statistics? You might be able to find the memory issues by looking at gc.
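A minimal sketch of that kind of tracking is shown below. Note that it swaps in gc.get_objects() and counts tracked objects by type, rather than calling gc.get_referents(); the helper names snapshot_type_counts and growth_between are made up for illustration. The garbage collector only sees objects in the calling process, so to inspect a dataloader worker the snapshots would have to be taken inside that worker (for example from __getitem__).

```python
import gc
from collections import Counter


def snapshot_type_counts() -> Counter:
    """Count the objects currently tracked by the garbage collector, by type name."""
    return Counter(type(obj).__name__ for obj in gc.get_objects())


def growth_between(before: Counter, after: Counter) -> Counter:
    """Return only the types whose object count grew between two snapshots."""
    # Counter subtraction drops entries that are zero or negative.
    return after - before


if __name__ == "__main__":
    before = snapshot_type_counts()
    leaked = [[i] for i in range(10_000)]  # stand-in for objects that accumulate
    after = snapshot_type_counts()
    # Print the five types that grew the most between the two snapshots.
    for type_name, delta in growth_between(before, after).most_common(5):
        print(f"{type_name}: +{delta}")
```

Taking a snapshot every few thousand samples and printing the growth shows whether any Python type is accumulating. If the counts stay flat while RAM keeps rising, the growth is more likely outside the Python heap or due to page copying in forked workers.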

VikaasVarma commented 4 months ago

This does not happen outside of Mosaic or with other datasets. Using PyTorch's DataLoader instead of the StreamingDataLoader also leads to a memory leak. When pulling the data using PyTorch's torchdata to construct the dataset, there is no significant memory overhead.

I don't think the problem lies within the StreamingDataLoader. The links seem to point towards large lists of Python objects generally causing this issue. There are a few cases of this in the StreamingDataset (the stored shards, spanners, stream filepaths, etc.).

tonyf commented 4 months ago

To echo @VikaasVarma's point here: the copy-on-read issues with the torch dataloader come back to the dataset object storing a large number of naive Python objects that can't use shared memory. I noticed that most of the dataset metadata is in fact in shared memory, except for self.shards, which is a list of Reader objects.

With smaller datasets, we never ran into this issue (or it never came up through the training lifecycle). We're only running into it now with a dataset that is a few orders of magnitude larger (more rows + larger row size and thus more shards).

If this is a copy-on-read issue, the memory wouldn't grow by a factor of row size, only by the number of shards, which I think is the case.
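One mitigation sometimes used for this class of problem is to pack a large list of Python objects into a single contiguous buffer, so that forked workers read one flat object instead of touching the reference count of every element. The sketch below uses only the standard library; PackedList is a hypothetical name, and whether Streaming's Reader objects can be pickled and packed this way is an assumption that is not verified here.

```python
import pickle
from array import array
from typing import Any, Sequence


class PackedList:
    """Read-only list whose items live in one contiguous bytes buffer."""

    def __init__(self, items: Sequence[Any]) -> None:
        blobs = [pickle.dumps(item, protocol=pickle.HIGHEST_PROTOCOL) for item in items]
        # offsets[i] and offsets[i + 1] delimit the pickled bytes of item i.
        self._offsets = array("Q", [0])
        for blob in blobs:
            self._offsets.append(self._offsets[-1] + len(blob))
        self._data = b"".join(blobs)

    def __len__(self) -> int:
        return len(self._offsets) - 1

    def __getitem__(self, index: int) -> Any:
        if not 0 <= index < len(self):
            raise IndexError(index)
        start, end = self._offsets[index], self._offsets[index + 1]
        # Items are deserialized on demand, so each access returns a fresh copy.
        return pickle.loads(self._data[start:end])


if __name__ == "__main__":
    # Hypothetical shard metadata, used only to exercise the class.
    shards = PackedList([{"shard": i, "path": f"shard.{i:05d}.mds"} for i in range(3)])
    print(len(shards), shards[1])
```

The trade-off is a pickle.loads call on every access, which is usually cheap next to decoding a sample but is not free.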

XiaohanZhangCMU commented 4 months ago

@VikaasVarma Is this a typo in your repro script?

def __getitem__(self, at: int) -> Sample:
    obj = super().__getitem__(at)

    _input = self.transforms(obj[self.input_key])
    cond = torch.tensor(obj[self.cond_key])

    if torch.rand(1) < self.cond_dropout_rate:
        cond = torch.zeros_like(cond) 

    return inputs, cond  # should be _input?

To clarify, did you expect "cond" to be garbage-collected by StreamingDataset?

XiaohanZhangCMU commented 4 months ago

@VikaasVarma can you clarify your plot a bit? E.g., what do pss, uss, and shared mean, and what are the x and y axes? Can you also provide a sample dataset so I can reproduce the plot? Thanks!

miguelalba96 commented 4 months ago

I tested again, training for a longer period:

image

my implementation:

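import io
import os
from ast import literal_eval
from functools import partial
from typing import Any, Callable, List

import torch
from PIL import Image
from streaming import StreamingDataset, StreamingDataLoader
from torchvision import transforms  # assumed source of `transforms` used below

import configs  # <- assumed import: my own module providing ExperimentConfig
import utils.visual_attribution # <- some of my modules for normal string manipulation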

class ImageCaptionDataset(StreamingDataset):
    def __init__(
            self,
            local: str,
            shuffle: bool,
            batch_size: int,
            transformations: Callable,
    ) -> None:
        super().__init__(
            local=local, shuffle=shuffle, batch_size=batch_size,
        )
        self.transformations = transformations

    @staticmethod
    def get_zero_shot_one_hot(zero_shot_attributes: List[int]):
        one_hot_encoded = torch.zeros(len(utils.visual_attribution.VISUAL_CLASSES), dtype=torch.float)
        one_hot_encoded[zero_shot_attributes] = 1.0
        return one_hot_encoded

    def __getitem__(self, idx: int) -> Any:
        obj = super().__getitem__(idx)
        image = Image.open(io.BytesIO(obj["image"]))
        caption = utils.visual_attribution.replace_article_type(obj["caption_simple"]) # does string replacement
        zero_shot_attr = self.get_zero_shot_one_hot(literal_eval(obj["zero_shot_attributes"]))
        return self.transformations(image), caption, zero_shot_attr

I call it like this:


def get_image_transformation_func(split: str):
    transformations = []
    if split == "train":
        transformations += [
            transforms.RandomHorizontalFlip(p=0.5),
            # v2.RandomVerticalFlip(p=0.5)
        ]
    transformations += [
        transforms.Lambda(lambda x: x)
    ]
    return transforms.Compose(transformations)

def collate_fn(batch, processor, tokenizer, max_length: int = 77):
    # samples come from the dataset as CxHxW
    images = processor(
        images=[ex[0] for ex in batch],
        return_tensors="pt"
    )
    captions = tokenizer(
        [ex[1] for ex in batch],
        padding="max_length",
        max_length=max_length,
        return_tensors="pt"
    )
    return {
        # "pixel_values": torch.stack([ex[0] for ex in batch]),
        "pixel_values": images["pixel_values"],
        "input_ids": captions["input_ids"],
        "attention_mask": captions["attention_mask"],
        "labels": torch.stack([ex[2] for ex in batch])
    }

def get_dataloader(split: str, config: configs.ExperimentConfig):
    transform_func = get_image_transformation_func(split)
    dataset = ImageCaptionDataset(
        local=os.path.join(config.local_data_path, split),
        shuffle=(split == "train"),
        batch_size=config.dataset_config.batch_size,
        transformations=transform_func,
    )
    return StreamingDataLoader(
        dataset,
        batch_size=config.dataset_config.batch_size,
        num_workers=config.dataset_config.num_workers,
        collate_fn=partial(
            collate_fn,
            processor=config.dataset_config.processor,
            tokenizer=config.dataset_config.tokenizer),
        drop_last=True,
        pin_memory=config.dataset_config.pin_memory,
        prefetch_factor=config.dataset_config.prefetch_factor,
        # persistent_workers=config.dataset_config.persistent_workers
    )

huxuan commented 1 month ago

I encountered a similar issue: CPU memory usage keeps increasing until OOM after about two hours.

nagadit commented 14 hours ago

Check this issue

https://github.com/mosaicml/streaming/issues/758

I found a memory leak problem; it comes from how the boto3 library works.