VikaasVarma opened this issue 7 months ago (Open)
I'm experiencing similar issues when loading image/text pairs (local). The RAM usage increases non-stop (GPU memory is stable). I managed to "solve" it partially by decreasing the number of workers in the dataloader (my max is 16 vCPUs per node, and each node has 2 GPUs), so I set num_workers=8 and disabled persistent_workers. I suspect that if I leave persistent_workers=True, the training will eventually crash. (Training happens after downloading the data locally to the nodes.) Maybe streaming is not cleaning up some state, leading to that memory accumulation?
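In code, the workaround amounts to dataloader settings like these (a minimal sketch; the dummy TensorDataset stands in for my real image/text dataset, and batch_size / pin_memory are placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy data standing in for the real StreamingDataset-backed dataset.
dataset = TensorDataset(torch.randn(1024, 3, 224, 224))

loader = DataLoader(
    dataset,
    batch_size=32,             # placeholder
    num_workers=8,             # halved from the 16 vCPUs available per node
    persistent_workers=False,  # tear workers down between epochs instead of keeping them alive
    pin_memory=True,           # placeholder
)
```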
Hey y’all, thanks for bringing this issue to our attention. We’re looking into this and will get back to you soon.
I skimmed through the blog and the PyTorch issue; is this particular to Streaming, or is it on the PyTorch side? StreamingDataLoader is a simple (stateful) subclass of PyTorch's DataLoader. Does this also happen with other Datasets? @VikaasVarma @miguelalba96
So Streaming is designed for fast random sample access from shards that live on disk. Outside of dataloader prefetching, samples are never kept in memory. We do use RAM for other things, like sample partitioning and shuffling, but that happens at the start of training. So I'm inclined to think this is a PyTorch DataLoader issue, given the links you sent as well.
To track memory usage, maybe you could call gc.get_referents() in a loop to collect memory statistics? You might be able to find the leak by looking at gc.
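For example, something along these lines (a rough sketch, not Streaming-specific) would show whether the count of some Python type keeps growing over time:

```python
import gc
from collections import Counter

def log_gc_stats(top_n: int = 10) -> None:
    # Count live Python objects by type; a count that grows without bound for
    # one type (list, dict, bytes, ...) points at what is accumulating.
    counts = Counter(type(o).__name__ for o in gc.get_objects())
    print(counts.most_common(top_n))

# e.g. inside the training loop:
# for step, batch in enumerate(loader):
#     ...
#     if step % 1000 == 0:
#         log_gc_stats()
```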
This does not happen outside of Mosaic or with other datasets. Using PyTorch's DataLoader instead of the StreamingDataLoader still leads to the memory leak, so the loader class isn't the culprit. When the same data is pulled with PyTorch's torchdata to construct the dataset, there is no significant memory overhead.
I don't think the problem lies within the StreamingDataLoader. The links seem to point towards large lists of Python objects generally causing this issue, and there are a few such cases in the StreamingDataset (the stored shards, the spanners, the stream filepaths, etc.).
To echo @VikaasVarma's point here: the copy-on-read issue with the torch dataloader comes back to the dataset object storing a large number of plain Python objects that can't use shared memory. I noticed that most of the dataset metadata is in fact in shared memory, except for self.shards, which is a list of Reader objects.
With smaller datasets, we never ran into this issue (or it never came up through the training lifecycle). We're only running into it now with a dataset that is a few orders of magnitude larger (more rows + larger row size and thus more shards).
If this is a copy-on-read issue, the memory wouldn't grow with row size, only with the number of shards, which I think is what we're seeing.
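To see the effect in isolation, here is a small Linux-only sketch (not from our training code; requires psutil) comparing a forked child that merely reads a big Python list against one that reads a flat numpy array:

```python
import multiprocessing as mp

import numpy as np
import psutil

N = 10_000_000
big_list = list(range(N))   # ~10M separate Python int objects, each with a refcount
big_array = np.arange(N)    # one contiguous buffer, no per-element objects

def touch(name, obj):
    proc = psutil.Process()
    before = proc.memory_full_info().uss
    total = 0
    for x in obj:           # merely *reading* list items bumps their refcounts,
        total += int(x)     # dirtying the copy-on-write pages so the kernel copies them
    after = proc.memory_full_info().uss
    print(f"{name}: child-private memory grew by {(after - before) / 1e6:.0f} MB")

if __name__ == "__main__":
    ctx = mp.get_context("fork")  # fork so the child starts out sharing the parent's pages
    for name, obj in [("list", big_list), ("ndarray", big_array)]:
        p = ctx.Process(target=touch, args=(name, obj))
        p.start()
        p.join()
```

The list case should grow by a few hundred MB per child while the array case barely moves, which is why per-shard Python objects like self.shards multiply across dataloader workers.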
@VikaasVarma Is this a typo in your repro script?
def __getitem__(self, at: int) -> Sample:
    obj = super().__getitem__(at)
    _input = self.transforms(obj[self.input_key])
    cond = torch.tensor(obj[self.cond_key])
    if torch.rand(1) < self.cond_dropout_rate:
        cond = torch.zeros_like(cond)
    return inputs, cond  # should be _input?
To clarify, did you expect "cond" to be garbage-collected by StreamingDataset?
@VikaasVarma can you clarify your plot a bit? For example, what do pss, uss, and shared mean, and what are the x and y axes? Can you also provide a sample dataset so I can reproduce the plot? Thanks!
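(I'm assuming the plot was made with psutil; if so, those fields would come from memory_full_info(), roughly:

```python
import psutil

def memory_breakdown(pid: int) -> dict:
    info = psutil.Process(pid).memory_full_info()
    return {
        "rss": info.rss,        # resident set size, includes shared pages
        "pss": info.pss,        # rss with shared pages counted proportionally (Linux)
        "uss": info.uss,        # pages unique to this process
        "shared": info.shared,  # pages shared with other processes (Linux)
    }
```

but please confirm what you actually measured.)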
I tested again, training for a longer period.
My implementation:
import io
from ast import literal_eval
from typing import Any, Callable, List

import torch
from PIL import Image
from streaming import StreamingDataset, StreamingDataLoader

import utils.visual_attribution  # <- one of my modules for normal string manipulation


class ImageCaptionDataset(StreamingDataset):
    def __init__(
        self,
        local: str,
        shuffle: bool,
        batch_size: int,
        transformations: Callable,
    ) -> None:
        super().__init__(
            local=local, shuffle=shuffle, batch_size=batch_size,
        )
        self.transformations = transformations

    @staticmethod
    def get_zero_shot_one_hot(zero_shot_attributes: List[int]):
        one_hot_encoded = torch.zeros(len(utils.visual_attribution.VISUAL_CLASSES), dtype=torch.float)
        one_hot_encoded[zero_shot_attributes] = 1.0
        return one_hot_encoded

    def __getitem__(self, idx: int) -> Any:
        obj = super().__getitem__(idx)
        image = Image.open(io.BytesIO(obj["image"]))
        caption = utils.visual_attribution.replace_article_type(obj["caption_simple"])  # does string replacement
        zero_shot_attr = self.get_zero_shot_one_hot(literal_eval(obj["zero_shot_attributes"]))
        return self.transformations(image), caption, zero_shot_attr
I call it like this:
import os
from functools import partial

import torch
from torchvision import transforms

import configs  # <- my experiment config module


def get_image_transformation_func(split: str):
    transformations = []
    if split == "train":
        transformations += [
            transforms.RandomHorizontalFlip(p=0.5),
            # v2.RandomVerticalFlip(p=0.5)
        ]
    transformations += [
        transforms.Lambda(lambda x: x)
    ]
    return transforms.Compose(transformations)


def collate_fn(batch, processor, tokenizer, max_length: int = 77):
    # samples come from the dataset as CxHxW
    images = processor(
        images=[ex[0] for ex in batch],
        return_tensors="pt"
    )
    captions = tokenizer(
        [ex[1] for ex in batch],
        padding="max_length",
        max_length=max_length,
        return_tensors="pt"
    )
    return {
        # "pixel_values": torch.stack([ex[0] for ex in batch]),
        "pixel_values": images["pixel_values"],
        "input_ids": captions["input_ids"],
        "attention_mask": captions["attention_mask"],
        "labels": torch.stack([ex[2] for ex in batch])
    }


def get_dataloader(split: str, config: configs.ExperimentConfig):
    transform_func = get_image_transformation_func(split)
    dataset = ImageCaptionDataset(
        local=os.path.join(config.local_data_path, split),
        shuffle=True if split == "train" else False,
        batch_size=config.dataset_config.batch_size,
        transformations=transform_func,
    )
    return StreamingDataLoader(
        dataset,
        batch_size=config.dataset_config.batch_size,
        num_workers=config.dataset_config.num_workers,
        collate_fn=partial(
            collate_fn,
            processor=config.dataset_config.processor,
            tokenizer=config.dataset_config.tokenizer),
        drop_last=True,
        pin_memory=config.dataset_config.pin_memory,
        prefetch_factor=config.dataset_config.prefetch_factor,
        # persistent_workers=config.dataset_config.persistent_workers
    )
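And roughly how I drive it while watching memory (a hypothetical logging helper, not my actual training loop; requires psutil):

```python
import time

import psutil

def run_and_log(split: str, config, log_every: int = 500) -> None:
    loader = get_dataloader(split, config)
    proc = psutil.Process()
    start = time.time()
    for step, batch in enumerate(loader):
        # batch is just consumed and dropped; we only care about memory growth here
        if step % log_every == 0:
            workers_rss = sum(c.memory_info().rss for c in proc.children(recursive=True))
            print(f"t={time.time() - start:6.0f}s step={step:6d} "
                  f"main_rss={proc.memory_info().rss / 1e9:.2f} GB "
                  f"workers_rss={workers_rss / 1e9:.2f} GB")
```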
I'm encountering a similar issue: the CPU memory usage keeps increasing until OOM in about two hours.
Check this issue: https://github.com/mosaicml/streaming/issues/758
I found a memory leak; it comes from the boto3 library.
I'm encountering a similar CPU memory leak when training on H800 GPUs.
Hey @huxuan @wanghao14 @miguelalba96, As @nagadit mentioned, there seems to be a memory leak with boto3 as detailed in boto/boto3#1670. If you are using boto3 / s3, can you verify if this is causing your problems?
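If it helps, one way to check whether boto3/botocore allocations are responsible is the stdlib tracemalloc module (a sketch, not a Streaming API; note it only sees the process it is started in, so run it inside worker_init_fn if you want to profile dataloader workers):

```python
import tracemalloc

tracemalloc.start(25)  # keep up to 25 stack frames per allocation

# ... after some training steps:
snapshot = tracemalloc.take_snapshot()
for stat in snapshot.statistics("lineno")[:20]:
    print(stat)  # look for botocore / boto3 frames among the top allocators
```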
@snarayan21 There is no boto3 in my code.
@snarayan21 I am using 1.5 TB of images stored in shards locally on 4 nodes, each with an entire copy of the data, so technically I am not streaming.
Environment
To reproduce
Steps to reproduce the behavior:
When using the StreamingDataLoader (or the vanilla PyTorch DataLoader) with num_workers>0, the processes slowly take more and more memory until the CPU RAM is filled.
Expected behavior
The dataloader should be able to provide samples indefinitely without using a significant portion of available RAM.
Additional context
Below is the dataset and dataloader implementation. Each sample is roughly 10 MB. With 16 workers, a prefetch factor of 4, and a batch size of 32, the total memory usage should be, at max, 20 GB. The dataset is made up of around 1.3 million shards.
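That upper bound works out as workers × prefetch factor × batch size × sample size:

```python
num_workers, prefetch_factor, batch_size, sample_mb = 16, 4, 32, 10
print(num_workers * prefetch_factor * batch_size * sample_mb / 1024)  # -> 20.0 GB
```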
A similar problem seems to be documented in an issue and a blog post. I have recreated the graphs found in the blog post below.