uber / petastorm

The Petastorm library enables single-machine or distributed training and evaluation of deep learning models from datasets in Apache Parquet format. It supports ML frameworks such as TensorFlow, PyTorch, and PySpark and can be used from pure Python code.

PyTorch Batched Non-shuffle Buffer Large Memory Consumption #763

Closed: chongxiaoc closed this issue 1 year ago

chongxiaoc commented 1 year ago

The batched non-shuffle buffer creates a new copy of the current row group every time it produces a batch: _make_batch() is called once per batch.

This leads to large memory consumption when generating batches from a very large row group. We used the script below to test and profile memory usage:

from petastorm import make_batch_reader
from petastorm.pytorch import BatchedDataLoader

reader_factory = make_batch_reader
feature_cols = ['your_feature_cols...']  # placeholder feature column names
label_cols = ['your_label_cols...']      # placeholder label column names
batch_size = 8192

data_path = "some_test_path"
# Single-shard reader over the Parquet dataset, with row-group shuffling disabled.
train_reader = reader_factory(data_path,
                              num_epochs=None, cur_shard=0,
                              shard_count=1, hdfs_driver="libhdfs",
                              schema_fields=feature_cols + label_cols,
                              shuffle_row_groups=False)

# shuffling_queue_capacity=0 selects the non-shuffle (noop) batched buffer.
train_dl = BatchedDataLoader(train_reader, batch_size=batch_size, shuffling_queue_capacity=0)

for batch in train_dl:
    print(batch)

A memory probe is injected at https://github.com/uber/petastorm/blob/fa8a8812e49f9a0fa604f557a0513bcf4d74281e/petastorm/reader_impl/pytorch_shuffling_buffer.py#L98, for example:

diff --git a/petastorm/reader_impl/pytorch_shuffling_buffer.py b/petastorm/reader_impl/pytorch_shuffling_buffer.py
index 6dac8a4..aac86d7 100644
--- a/petastorm/reader_impl/pytorch_shuffling_buffer.py
+++ b/petastorm/reader_impl/pytorch_shuffling_buffer.py
@@ -106,6 +106,11 @@ class BatchedNoopShufflingBuffer(BatchedShufflingBufferBase):
             self._batches = []
         self._num_samples -= min(self._num_samples, self.batch_size)
         self.store.append(batch)
+        import psutil
+        print("used mem:", psutil.virtual_memory())

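For reference, per-batch memory growth can also be observed from the driving loop without patching the library. The sketch below assumes psutil is installed and reuses train_dl from the test script above; it simply reports the resident set size of the process after each batch.

import psutil

process = psutil.Process()
for i, batch in enumerate(train_dl):
    # Resident set size of this process (in MiB) after the batch is materialized.
    rss_mib = process.memory_info().rss / (1024 * 1024)
    print(f"batch {i}: rss={rss_mib:.1f} MiB")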
The result for batch_size = 8192, on a Parquet file on HDFS with 20M rows per row group and 50 double-type columns, is shown below:

[Screenshot from 2022-07-12: memory usage readings per batch]

Applying PR #762, which reuses the same row-group buffer, brings the memory usage down.
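The underlying idea of reusing the buffer can be illustrated with a simplified sketch (plain PyTorch, not the actual code of PR #762): producing each batch as a fresh copy of a slice of the row-group tensor allocates new storage per batch, while handing out narrow views over a single retained buffer does not.

import torch

# Hypothetical in-memory row group, scaled down for illustration
# (the real case above is 20M rows x 50 double columns).
row_group = torch.randn(200_000, 50, dtype=torch.float64)
batch_size = 8192

def batches_with_copies(buf, bs):
    # Each batch is an independent copy: new storage is allocated for every batch.
    for start in range(0, buf.shape[0], bs):
        yield buf[start:start + bs].clone()

def batches_as_views(buf, bs):
    # Each batch is a view into the same retained buffer: no per-batch allocation.
    for start in range(0, buf.shape[0], bs):
        yield buf.narrow(0, start, min(bs, buf.shape[0] - start))

A view shares storage with the row-group buffer, so the large buffer is held in memory only once; cloned batches, in contrast, each allocate fresh memory on top of it.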

chongxiaoc commented 1 year ago

@selitvin FYI