Open apsdehal opened 2 years ago
Are you sure there is a leak? How can I see it? You shared the script but not the output which you believe should indicate a leak.
I modified your reproduction script to print only once per try, as your original was printing too much info. You absolutely must call gc.collect() before taking any memory measurement: Python's GC runs on a schedule, so otherwise you might be measuring the wrong thing. This gives us:
import psutil
import os
import gc
from transformers import BertTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
BATCH_SIZE = 32
NUM_TRIES = 100
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def transform(x):
    x.update(tokenizer(x["text"], return_tensors="pt", max_length=64, padding="max_length", truncation=True))
    x.pop("text")
    x.pop("label")
    return x
dataset = load_dataset("imdb", split="train")
dataset.set_transform(transform)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
count = 0
while count < NUM_TRIES:
    for idx, batch in enumerate(train_loader):
        pass
    gc.collect()
    mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
    print(count, mem_after - mem_before)
    count += 1
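To see why the gc.collect() call matters, here is a minimal stdlib-only sketch (not part of the original script; the Node class and the sizes are made up for illustration). It builds objects that sit in reference cycles and uses tracemalloc to show that they stay allocated until the collector actually runs:

```python
import gc
import tracemalloc


class Node:
    """Hypothetical object that refers to itself, forming a reference cycle."""

    def __init__(self):
        self.payload = bytearray(100_000)  # ~100KB per object
        self.me = self  # the cycle: refcounting alone can never free this


gc.disable()  # keep the demo deterministic: no automatic collections
tracemalloc.start()
try:
    baseline, _ = tracemalloc.get_traced_memory()

    for _ in range(50):
        Node()  # immediately unreachable, but still allocated

    before_collect, _ = tracemalloc.get_traced_memory()
    gc.collect()  # now the cycles are found and freed
    after_collect, _ = tracemalloc.get_traced_memory()
finally:
    tracemalloc.stop()
    gc.enable()

leaked_before = before_collect - baseline
leaked_after = after_collect - baseline
print(f"before gc.collect(): {leaked_before / 2**20:.2f}MB still allocated")
print(f"after  gc.collect(): {leaked_after / 2**20:.2f}MB still allocated")
```

A measurement taken at the "before" point would look like a leak of several MB even though nothing is leaking.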
Now running it:
$ python dl-leak.py
Reusing dataset imdb (/home/stas/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)
0 4.43359375
1 4.4453125
2 4.44921875
3 4.44921875
4 4.4609375
5 4.46484375
6 4.46484375
7 4.46484375
8 4.46484375
9 4.46484375
10 4.46484375
11 4.46484375
12 4.46484375
13 4.46484375
14 4.46484375
15 4.46484375
16 4.46484375
It's normal that at the beginning there is a small growth in memory usage, but after 5 cycles it gets steady.
Unless of course you're referring to the memory growth during the first try. Is that what you're referring to? And since your ds is small it's hard to see the growth - could it be just because some records are longer and it needs to allocate more memory for those?
Though while experimenting with this I have observed a peculiar thing: if I concatenate 2 datasets, I don't see any growth at all. But that's probably because the program allocated additional peak RSS memory to concatenate and then is re-using the memory.
Basically, I wanted to see what happens when the dataset is much longer: I'd expect not to see any memory growth once the 780 records of the imdb ds have been processed once.
It is hard to say if it is directly reproducible in this setup. Maybe it is specific to the images stored in the CM4 case which cause a memory leak. I am still running your script and seeing if I can reproduce that particular leak in this case.
I was able to reproduce the leak with:
import psutil
import os
import gc
from datasets import load_from_disk
import time
DATASET_PATH = "/hf/m4-master/data/cm4/cm4-10000-v0.1"
dataset = load_from_disk(DATASET_PATH)
# truncate to a tiny dataset
dataset = dataset.select(range(1000))
print(f"dataset: {len(dataset)} records")
mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
for idx, rec in enumerate(dataset):
    if idx % 100 == 0:
        gc.collect()
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")
You need to adjust DATASET_PATH to point at your copy of the dataset, which you can get with:
gsutil -m cp "gs://hf-science-m4/cm4/cm4-10000-v0.1/dataset.arrow" "gs://hf-science-m4/cm4/cm4-10000-v0.1/dataset_info.json" "gs://hf-science-m4/cm4/cm4-10000-v0.1/state.json" .
(I assume the hf folks have the perms) - it's a smallish dataset (10k)
then you run:
$ python ds.py
dataset: 1000 records
0 1.0156MB
100 126.3906MB
200 142.8906MB
300 168.5586MB
400 218.3867MB
500 230.7070MB
600 238.9570MB
700 263.3789MB
800 288.1289MB
900 300.5039MB
you should be able to see the leak
This issue has nothing to do with PIL's decoder. I removed it and the problem is still there.
I then traced this leak to this single call: pa_table.to_pydict()
here:
I can make it leak much faster by modifying that code to repeat pa_table.to_pydict() many times in a row. It shouldn't have that impact:
class PythonArrowExtractor(BaseArrowExtractor[dict, list, dict]):
    def extract_row(self, pa_table: pa.Table) -> dict:
        x = [pa_table.to_pydict() for x in range(200)]
        return _unnest(pa_table.to_pydict())
@lhoestq - do you know what might be happening inside pa_table.to_pydict(), as this is in the pyarrow domain? Perhaps you know someone to tag from that project?
Probably the next step is to remove datasets from the equation and make a reproducible case with just pyarrow directly.
The problem already happens with pyarrow==6.0.0 or later (the minimum for current datasets).
I'm also trying to dig in with objgraph to see if there are any circular references which prevent objects from being freed, but no luck there so far. And I'm pretty sure to_pydict is not Python code, so the problem likely happens somewhere outside of Python's GC.
This appears to be the same issue I think: https://github.com/huggingface/datasets/issues/4528 I dug into the repro code there and it's the same behavior with the same leak, but it's a pure nlp dataset and thus much faster to work with.
I went all the way back to pyarrow==1.0.0 and datasets==1.12.0 and the problem is still there. How is it even possible that it wasn't noticed all this time?
Could it be that the leak is in some 3rd party component pyarrow relies on? While downgrading I only downgraded the above 2 packages.
Also found this warning
Be careful: if you don't pass the ArrowArray struct to a consumer, array memory will leak. This is a low-level function intended for expert users.
perhaps something triggers this condition?
I have no idea if it's related - this is just something that came up during my research.
Does it crash with OOM at some point? If it doesn't, it isn't a leak, just aggressive caching or a custom allocator that doesn't like to give memory back (not uncommon). #4528 looks like it hits a steady state.
I believe the underlying arrow libs use a custom C allocator. Some of those are designed not to give back to OS, but keep heap memory for themselves to re-use (hitting up the OS involves more expensive mutex locks, contention, etc). The greedy behaviour can be undesirable though. There are likely flags to change the allocator behaviour, and one could likely build without any custom allocators (or use a different one).
Does it crash with OOM at some point?
In the original setup where we noticed this problem, it was indeed ending in an OOM
https://github.com/huggingface/datasets/issues/4528 looks like it hits a steady state.
@rwightman in the plot I shared, the steady state comes from the time.sleep(100) I added at the end of the script, to showcase that even the garbage collector couldn't free that allocated memory.
Could this be related to this discussion about a potential memory leak in pyarrow: https://issues.apache.org/jira/browse/ARROW-11007 ?
(Note: I've tried import pyarrow; pyarrow.jemalloc_set_decay_ms(0) and the memory leak is still happening on your toy example)
@lhoestq - do you know what might be happening inside pa_table.to_pydict(), as this is in the pyarrow domain. Perhaps you know someone to tag from that project?
to_pydict calls to_pylist on each column (i.e. on each PyArrow Array). Then it iterates on the array and calls as_py on each element. The as_py implementation depends on the data type. For strings I think it simply gets the buffer that contains the binary string data that is defined in C++.
The Arrow team is pretty responsive at user@arrow.apache.org if it can help
Probably next need to remove datasets from the equation and make a reproducible case with just pyarrow directly.
That would be ideal indeed. Would be happy to help on this, can you give me access to the bucket so I can try with your data ?
That would be ideal indeed. Would be happy to help on this, can you give me access to the bucket so I can try with your data ?
I added you to the bucket @lhoestq
It looks like an issue with memory mapping: keep_in_memory=True in load_from_disk loads the dataset in RAM, and doesn't cause any memory leak.
Here is code to reproduce this issue using only PyArrow and a dummy arrow file:
import psutil
import os
import gc
import pyarrow as pa
import time
ARROW_PATH = "tmp.arrow"
if not os.path.exists(ARROW_PATH):
    arr = pa.array([b"a" * (200 * 1024)] * 1000) # ~200MB
    table = pa.table({"a": arr})
    with open(ARROW_PATH, "wb") as f:
        writer = pa.RecordBatchStreamWriter(f, schema=table.schema)
        writer.write_table(table)
        writer.close()
def memory_mapped_arrow_table_from_file(filename: str) -> pa.Table:
    memory_mapped_stream = pa.memory_map(filename)
    opened_stream = pa.ipc.open_stream(memory_mapped_stream)
    pa_table = opened_stream.read_all()
    return pa_table
table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr = table[0]
mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
for idx, x in enumerate(arr):
    if idx % 100 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")
prints
0 0.2500MB
100 19.8008MB
200 39.3320MB
300 58.8633MB
400 78.3945MB
500 97.9258MB
600 117.4570MB
700 136.9883MB
800 156.5195MB
900 176.0508MB
Note that this example simply iterates over the pyarrow.lib.BinaryScalar objects in the array. Running .as_py() is not needed to experience the memory issue.
@lhoestq that does indeed increase in memory, but if you iterate over the array again after the first time, or re-open and remap the same file (repeat table = memory_mapped_arrow_table_from_file(ARROW_PATH)) before re-iterating, it doesn't move past 195MB... it would appear another step is needed to continue consuming memory past that... hmmm
Are the pa_tables held on to anywhere after they are iterated in the real code?
In my hack, if you do a bunch of cut & paste and then change the arr name for each iteration:
table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr = table[0]
for idx, x in enumerate(arr):
    if idx % 100 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")
table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr1 = table[0]
for idx, x in enumerate(arr1):
    if idx % 100 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")
table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr2 = table[0]
for idx, x in enumerate(arr2):
    if idx % 100 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")
it leaks. If all the arr variables have the same name (so the previous one gets cleaned up), it does not, and goes back to 0. Could anything be holding onto a reference to an intermediary equivalent of arr in the real use case?
Yes, we have already established here https://github.com/huggingface/datasets/issues/4883#issuecomment-1232063891 that when one iterates over the whole dataset multiple times, it consumes a bit more memory in the next few repetitions and then remains steady.
Which means that when a new iterator is created over the same dataset, all the memory from the previous iterator is re-used.
So the leak happens primarily when the iterator is "drained" the first time, which tells me that either a circular reference is created somewhere which only gets released when the iterator is destroyed, or there is some global variable that keeps piling up the memory and doesn't release it in time.
Also I noticed some __del__ methods, which won't destroy objects automatically, and there is usually a warning against using them: https://stackoverflow.com/a/1481512/9201239
There are also some weakrefs in the code, which may also lead to leaks or weird problems at times.
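One way to test the "iterator traps objects" hypothesis without objgraph is to keep only weak references to what an iterator yields and count how many are still alive after draining it. A minimal sketch with stand-in code (Record, leaky_iter and clean_iter are hypothetical, not datasets or pyarrow code):

```python
import gc
import weakref


class Record:
    """Hypothetical stand-in for a row object yielded by an iterator."""

    def __init__(self, idx):
        self.idx = idx


def clean_iter(n):
    # Well-behaved iterator: keeps no reference to what it yields.
    for i in range(n):
        yield Record(i)


def leaky_iter(n, trap):
    # Simulated bug: every yielded record is also stored in a long-lived list.
    for i in range(n):
        rec = Record(i)
        trap.append(rec)
        yield rec


def count_survivors(make_iter):
    """Drain the iterator, then count yielded objects that are still alive."""
    refs = []
    for rec in make_iter():
        refs.append(weakref.ref(rec))  # a weak ref does not keep rec alive
    rec = None  # drop our own reference to the last record
    gc.collect()
    return sum(ref() is not None for ref in refs)


trap = []
clean_survivors = count_survivors(lambda: clean_iter(100))
leaky_survivors = count_survivors(lambda: leaky_iter(100, trap))
print(f"clean iterator: {clean_survivors} objects still alive")
print(f"leaky iterator: {leaky_survivors} objects still alive")
```

This only detects Python-level references. If the memory is held somewhere in C++ code, the weak references will all be dead and the count will be 0 either way.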
@stas00 my point was, I'm not convinced @lhoestq's last example illustrates the leak, but rather the differences between memory-mapped and in-memory usage patterns. If you destroy arr, the memory map impl goes back to 0 each iteration. The amount of memory that 'looks' like it is leaked in the first pass differs quite a bit between memory mapped vs in memory, but the underlying issue is likely a circular reference, or reference(s) which were not cleaned up, that would impact either case but is likely much more visible with mmap.
Thank you for clarifying, Ross.
I think we agree that it's almost certain that the datasets iterator traps some inner variable that prevents object freeing, since if we create the iterator multiple times (and drain it), after a few runs no new memory is allocated. We could try to dig in more with objgraph - my main concern is if the problem happens somewhere outside of python (i.e. in the pyarrow cpp implementation), in which case it'd be much more difficult to trace.
I wish there was a way on linux to tell the program to free no longer used memory at will.
FWIW, I revisited some code I had in the works to use HF datasets w/ timm train & val scripts. There is no leak there across multiple epochs. It uses the defaults.
It's worth noting that with imagenet keep_in_memory=True isn't even an option because the train arrow file is ~140GB and my local memory is less. The virtual address space reflects mmap (> 150GB) and doesn't increase over epochs that I noticed. I have some perf issues to bring up wrt the current setup, but that's a separate and lower prio discussion to have elsewhere...
After reading many issues and trying many things, here is the summary of what I learned.
I'm now using @lhoestq's repro case as it's pyarrow-isolated: https://github.com/huggingface/datasets/issues/4883#issuecomment-1242034985
PyArrow has 3 memory pool backends; I tried them all with the same results:
pa.set_memory_pool(pa.jemalloc_memory_pool())
pa.set_memory_pool(pa.mimalloc_memory_pool())
pa.set_memory_pool(pa.system_memory_pool())
The jemalloc backend supports quick release:
pa.jemalloc_set_decay_ms(0)
It doesn't make any difference in this case.
This is a useful tracer for PA memory allocators:
pa.log_memory_allocations(enable=True)
It nicely reports memory allocations and releases when the arrow file is created the first time, but when we then try to do enumerate(arr) this logger reports 0 allocations.
This summary also reports no allocations when the script is run the second time (post file creation):
mem_pool = pa.default_memory_pool()
print(f"PyArrow mem pool info: {mem_pool.backend_name} backend, {mem_pool.bytes_allocated()} allocated, "
f"{mem_pool.max_memory()} max allocated, ")
print(f"PyArrow total allocated bytes: {pa.total_allocated_bytes()}")
However, it's easy to see by using tracemalloc, which only measures Python's memory allocations, that it's PA that leaks, since tracemalloc shows fixed memory (this is bolted on top of the original repro script):
import tracemalloc
tracemalloc.start()
[...]
for idx, x in enumerate(arr):
    if idx % 10 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / 2**20
        mem_use = pa.total_allocated_bytes() - start_use
        mem_peak = pool.max_memory() - start_peak_use
        second_size, second_peak = tracemalloc.get_traced_memory()
        mem_diff = (second_size - first_size) / 2**20
        mem_peak_diff = (second_peak - first_peak) / 2**20
        # pa.jemalloc_memory_pool().release_unused()
        # pa.mimalloc_memory_pool().release_unused()
        # pa.system_memory_pool().release_unused()
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB {mem_diff:12.4f} {mem_peak_diff:12.4f} {memory_mapped_stream.size()/2**20:4.4}MB {mem_use/2**20:4.4}MB {mem_peak/2**20:4.4}MB")
gives:
0 5.4258MB 0.0110 0.0201 195.3MB 0.0MB 0.0MB
10 25.3672MB 0.0112 0.0202 195.3MB 0.0MB 0.0MB
20 45.9336MB 0.0112 0.0203 195.3MB 0.0MB 0.0MB
30 62.4336MB 0.0112 0.0203 195.3MB 0.0MB 0.0MB
40 83.0586MB 0.0112 0.0203 195.3MB 0.0MB 0.0MB
50 103.6836MB 0.0112 0.0203 195.3MB 0.0MB 0.0MB
60 124.3086MB 0.0112 0.0203 195.3MB 0.0MB 0.0MB
70 140.8086MB 0.0112 0.0203 195.3MB 0.0MB 0.0MB
80 161.4336MB 0.0112 0.0203 195.3MB 0.0MB 0.0MB
90 182.0586MB 0.0112 0.0203 195.3MB 0.0MB 0.0MB
the 3rd and 4th columns are tracemalloc's report.
the 5th column is the size of the mmapped stream - fixed.
the last 2 are PA's malloc reports - you can see they are totally fixed at 0.
So what gives? PA's memory allocator says nothing was allocated and we can see python doesn't allocate any memory either.
Someone suggested in one of the PA issues that IPC/GRPC could be the issue. Any suggestions on how to debug this one?
The main issue is that one can't step through with a python debugger, as arr is an opaque cpp object bound to python.
Please see the next comment for a possible answer.
I also traced reference counts and they are all fixed, using either sys.getrefcount(x) or len(gc.get_referrers(x)), so it's not the python object.
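For readers who haven't used these two calls, here is a minimal sketch of that kind of check (Payload and holder are hypothetical names for the demo). The absolute numbers depend on the Python version and on where the code runs, so only the change between two readings is meaningful:

```python
import gc
import sys


class Payload:
    """Hypothetical stand-in for an object suspected of being kept alive."""


payload = Payload()
holder = []

refs_before = sys.getrefcount(payload)
referrers_before = len(gc.get_referrers(payload))

holder.append(payload)  # simulate a container quietly keeping a reference

refs_after = sys.getrefcount(payload)
referrers_after = len(gc.get_referrers(payload))
held_by_holder = any(r is holder for r in gc.get_referrers(payload))

print(f"refcount:  {refs_before} -> {refs_after}")
print(f"referrers: {referrers_before} -> {referrers_after}")
print(f"held by holder: {held_by_holder}")
```

If both readings stay fixed across iterations, as reported above, no new Python-level reference to the object is being created.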
https://issues.apache.org/jira/browse/ARROW-11007 - looks very similar to our issue in particular this part of the report: https://issues.apache.org/jira/browse/ARROW-11007?focusedCommentId=17279642&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-17279642
Next, let's revisit @rwightman's suggestion that there is actually no leak.
After all - we are using mmap which will try to map the file to RAM as much as it can and then page out if there is no memory. i.e. MMAP is only fast if you have a lot of CPU RAM.
So let's do it:
We first quickly start a cgroups-controlled shell which will instantly kill any program that consumes more than 1GB of memory:
$ systemd-run --user --scope -p MemoryHigh=1G -p MemoryMax=1G -p MemorySwapMax=1G --setenv="MEMLIMIT=1GB" bash
Let's check that it indeed does so. Let's change @lhoestq's script to allocate a 10GB arrow file:
$ python -c 'import pyarrow as pa; pa.array([b"a" * (2000 * 1024)] * 5000)'
Killed
oops, that didn't work, as we tried to allocate 10GB when only 1GB is allowed. This is what we want!
Let's do a sanity check - can we allocate 0.1GB?
$ python -c 'import pyarrow as pa; pa.array([b"a" * (2000 * 1024)] * 50)'
Yes. So the limited shell does the right thing. It lets us allocate < 1GB of RSS RAM.
Next let's go back to @lhoestq's script but with a 10GB arrow file.
We change his repro script https://github.com/huggingface/datasets/issues/4883#issuecomment-1242034985 to use a 50x larger file:
arr = pa.array([b"a" * (2000 * 1024)] * 5000) # ~10000MB
We first have to run it in a normal unlimited shell so that we don't get killed (as the script allocates 10GB).
Now let's run the script in the 1GB-limited shell while running a monitor:
$ htop -F python -s M_RESIDENT -u `whoami`
so we have 2 sources of RSS info just in case.
$ python pyar.py
0 4.3516MB 0.0103 0.0194 9.766e+03MB 0.0MB 0.0MB
10 24.3008MB 0.0104 0.0195 9.766e+03MB 0.0MB 0.0MB
[...]
4980 9730.3672MB 0.0108 0.0199 9.766e+03MB 0.0MB 0.0MB
4990 9750.9922MB 0.0108 0.0199 9.766e+03MB 0.0MB 0.0MB
PyArrow mem pool info: jemalloc backend, 0 allocated, 0 max allocated,
PyArrow total allocated bytes: 0
But wait, it reported 10GB RSS both in htop and in our log!
So that means it never allocated 10GB, otherwise it'd have been killed.
Which tells us that there is no leak whatsoever and this is just a really difficult situation where MMAPPED memory is reported as part of RSS, which it probably shouldn't be. So now we have no way to measure real memory usage.
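The same effect can be illustrated without pyarrow. The following stdlib-only sketch (the file size and names are arbitrary choices for the demo, not taken from the scripts above) memory-maps a scratch file, touches every page of it, and shows that Python's own allocations stay flat while doing so. On Linux the touched pages typically show up in the process RSS anyway, because they are file-backed pages that the kernel can drop at any time rather than memory the program allocated:

```python
import mmap
import os
import tempfile
import tracemalloc

SIZE = 32 * 2**20  # 32MB scratch file, an arbitrary size for the demo
PAGE = mmap.PAGESIZE


def touch_all_pages(path):
    """Read one byte from every page of the memory-mapped file."""
    with open(path, "rb") as fh:
        with mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            checksum = 0
            for offset in range(0, len(mm), PAGE):
                checksum += mm[offset]  # forces the page to be paged in
            return checksum


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "scratch.bin")
    with open(path, "wb") as fh:
        fh.write(b"a" * SIZE)

    tracemalloc.start()
    start, _ = tracemalloc.get_traced_memory()
    checksum = touch_all_pages(path)
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()

traced_growth_mb = (current - start) / 2**20
traced_peak_mb = (peak - start) / 2**20
print(f"mapped and touched: {SIZE / 2**20:.0f}MB")
print(f"tracemalloc growth: {traced_growth_mb:.4f}MB (peak {traced_peak_mb:.4f}MB)")
```

Adding an RSS reading before and after touch_all_pages() (for example with psutil, as in the scripts above) is a way to compare the two views on a given system.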
I also attached the script with all the different things I have tried in it, so it should be easy to turn them on/off if you want to reproduce any of my findings.
just rename it to pyra.py as gh doesn't allow attaching scripts...
(I have to remember to exit that special mem-limited shell or else I won't be able to do anything serious there.)
The original leak in the multi-modal code is very likely something else. But of course now it'd be very difficult to trace it using mmap.
I think to debug we have to set keep_in_memory=True in load_from_disk to load the small dataset in RAM, so there will be no misleading mmap reporting component, and then continue searching for another source of a leak.
To add to what @stas00 found, I'm gonna leave some links to where I believe the confusion came from in pyarrow's APIs, for future reference:
Arrow can directly reference the data mapped from disk and avoid having to allocate its own memory.
And where their example shows 0 RSS memory allocation, the way we used to measure RSS shows 39.6719MB allocated. Here's the script to reproduce:
import psutil
import os
import gc
import pyarrow as pa
import time
import sys
# gc.set_debug(gc.DEBUG_LEAK)
# gc.set_threshold(0,0,0)
#pa.set_memory_pool(pa.mimalloc_memory_pool())
#pa.set_memory_pool(pa.system_memory_pool())
import tracemalloc
#pa.jemalloc_set_decay_ms(0)
# pa.log_memory_allocations(enable=True)
BATCH_SIZE = 10000
NUM_BATCHES = 1000
schema = pa.schema([pa.field('nums', pa.int32())])
with pa.OSFile('bigfile.arrow', 'wb') as sink:
    with pa.ipc.new_file(sink, schema) as writer:
        for row in range(NUM_BATCHES):
            batch = pa.record_batch([pa.array(range(BATCH_SIZE), type=pa.int32())], schema)
            writer.write(batch)
start_use = pa.total_allocated_bytes()
pool = pa.default_memory_pool()
start_peak_use = pool.max_memory()
tracemalloc.start()
first_size, first_peak = tracemalloc.get_traced_memory()
mem_before = psutil.Process(os.getpid()).memory_info().rss / 2**20
# with pa.OSFile('bigfile.arrow', 'rb') as source:
#     loaded_array = pa.ipc.open_file(source).read_all()
with pa.memory_map('bigfile.arrow', 'rb') as source:
    loaded_array = pa.ipc.open_file(source).read_all()
print("LEN:", len(loaded_array))
print("RSS: {}MB".format(pa.total_allocated_bytes() >> 20))
gc.collect()
time.sleep(0.1)
mem_after = psutil.Process(os.getpid()).memory_info().rss / 2**20
mem_use = pa.total_allocated_bytes() - start_use
mem_peak = pool.max_memory() - start_peak_use
second_size, second_peak = tracemalloc.get_traced_memory()
mem_diff = (second_size - first_size) / 2**20
mem_peak_diff = (second_peak - first_peak) / 2**20
idx = 0
print(f"{idx:4d} {mem_after - mem_before:12.4f}MB {mem_diff:12.4f} {mem_peak_diff:12.4f} {mem_use/2**20:4.4}MB {mem_peak/2**20:4.4}MB")
gives:
LEN: 10000000
RSS: 0MB
0 39.6719MB 0.0132 0.0529 0.0MB 0.0MB
Which again just proves that we incorrectly measure RSS in the case of MMAPPED memory.
@lhoestq, I have been working on a detailed article that shows that MMAP doesn't leak and it's mostly ready. I will share when it's ready.
The issue is that we still need to be able to debug memory leaks by turning MMAP off.
But once I tried to show the user that using load_dataset(... keep_in_memory=True) is the way to debug an actual memory leak - guess what I discovered? A potential actual leak.
Here is the repro:
$ cat ds-mmap.py
from datasets import load_dataset
import gc
import os
import psutil
proc = psutil.Process(os.getpid())
def mem_read():
    gc.collect()
    return proc.memory_info().rss / 2**20
dataset = load_dataset("wmt19", 'cs-en', keep_in_memory=True, streaming=False)['train']
print(f"{'idx':>6} {'RSS':>10} {'Δ RSS':>15}")
step = 20000
for i in range(0, 10*step, step):
    mem_before = mem_read()
    _ = dataset[i:i+step]
    mem_after = mem_read()
    print(f"{i:6d} {mem_after:12.4f}MB {mem_after - mem_before:12.4f}MB ")
python ds-io.py
Reusing dataset wmt19 (/home/stas/.cache/huggingface/datasets/wmt19/cs-en/1.0.0/c3db1bf4240362ed1ef4673b354f468d70aac66d4e67d45f536d493a0840f0d3)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 5.66it/s]
idx RSS Δ RSS
0 1398.4609MB 3.5195MB
20000 1398.5742MB 0.1133MB
40000 1398.6016MB 0.0273MB
60000 1398.6016MB 0.0000MB
80000 1398.6016MB 0.0000MB
100000 1398.6328MB 0.0312MB
120000 1398.6953MB 0.0625MB
140000 1398.6953MB 0.0000MB
160000 1398.7500MB 0.0547MB
180000 1398.7500MB 0.0000MB
As I suggested on slack, perhaps it was due to variation in dataset record length, so with your help I wrote another repro with synthetic records which are all identical. That should remove my hypothesis from the equation, and we should expect 0 incremental growth as we iterate over the dataset. But alas this is not the case. There is a tiny but definite leak-like behavior.
Here is the new repro:
$ cat ds-synthetic-no-mmap.py
from datasets import load_from_disk, Dataset
import gc
import sys
import os
import psutil
proc = psutil.Process(os.getpid())
def mem_read():
    gc.collect()
    return proc.memory_info().rss / 2**20
DS_PATH = "synthetic-ds"
if not os.path.exists(DS_PATH):
    records = 1_000_000
    print("Creating a synthetic dataset")
    row = dict(foo=[dict(a='a'*500, b='b'*1000)])
    ds = Dataset.from_dict({k: [v] * records for k, v in row.items()})
    ds.save_to_disk(DS_PATH)
    print("Done. Please restart the program")
    sys.exit()
dataset = load_from_disk(DS_PATH, keep_in_memory=True)
print(f"Dataset len={len(dataset)}")
print(f"{'idx':>8} {'RSS':>10} {'Δ RSS':>15}")
mem_start = 0
step = 25_000
warmup_iterations = 4
for idx, i in enumerate(range(0, len(dataset), step)):
    if idx == warmup_iterations: # skip the first few iterations while things get set up
        mem_start = mem_read()
    mem_before = mem_read()
    _ = dataset[i:i+step]
    mem_after = mem_read()
    print(f"{i:8d} {mem_after:12.4f}MB {mem_after - mem_before:12.4f}MB")
mem_end = mem_read()
print(f"Total diff: {mem_end - mem_start:12.4f}MB (after {warmup_iterations} warmup iterations)")
and the run:
$ python ds-synthetic-no-mmap.py
Dataset len=1000000
idx RSS Δ RSS
0 1601.9258MB 47.9688MB
25000 1641.6289MB 39.7031MB
50000 1641.8594MB 0.2305MB
75000 1642.1289MB 0.2695MB
100000 1642.1289MB 0.0000MB
125000 1642.3789MB 0.2500MB
150000 1642.3789MB 0.0000MB
175000 1642.6289MB 0.2500MB
200000 1642.6289MB 0.0000MB
225000 1642.8789MB 0.2500MB
250000 1642.8828MB 0.0039MB
275000 1643.1328MB 0.2500MB
300000 1643.1328MB 0.0000MB
325000 1643.3828MB 0.2500MB
350000 1643.3828MB 0.0000MB
375000 1643.6328MB 0.2500MB
400000 1643.6328MB 0.0000MB
425000 1643.8828MB 0.2500MB
450000 1643.8828MB 0.0000MB
475000 1644.1328MB 0.2500MB
500000 1644.1328MB 0.0000MB
525000 1644.3828MB 0.2500MB
550000 1644.3828MB 0.0000MB
575000 1644.6328MB 0.2500MB
600000 1644.6328MB 0.0000MB
625000 1644.8828MB 0.2500MB
650000 1644.8828MB 0.0000MB
675000 1645.1328MB 0.2500MB
700000 1645.1328MB 0.0000MB
725000 1645.3828MB 0.2500MB
750000 1645.3828MB 0.0000MB
775000 1645.6328MB 0.2500MB
800000 1645.6328MB 0.0000MB
825000 1645.8828MB 0.2500MB
850000 1645.8828MB 0.0000MB
875000 1646.1328MB 0.2500MB
900000 1646.1328MB 0.0000MB
925000 1646.3828MB 0.2500MB
950000 1646.3828MB 0.0000MB
975000 1646.6328MB 0.2500MB
Total diff: 4.5039MB (after 4 warmup iterations)
So I'm still not sure why we get this.
As you can see, I started skipping the first few iterations where memory isn't stable yet, as the actual diff is much larger if we count all iterations.
What do you think?
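As a quick sanity check on the scale of this growth, the numbers from the run above can be plugged into a few lines of arithmetic. The RSS values are copied from the log; the per-record figure is just the total growth divided evenly and says nothing about where the memory goes:

```python
MB = 2**20

step = 25_000
warmup_iterations = 4
total_iterations = 1_000_000 // step  # 40 slices of 25k records each

rss_after_warmup_mb = 1642.1289  # RSS right before the first measured slice
rss_at_end_mb = 1646.6328  # RSS after the last slice

growth_mb = rss_at_end_mb - rss_after_warmup_mb
measured_records = (total_iterations - warmup_iterations) * step
growth_pct = 100 * growth_mb / rss_after_warmup_mb
bytes_per_record = growth_mb * MB / measured_records

print(f"growth after warmup: {growth_mb:.4f}MB over {measured_records} records")
print(f"relative growth:     {growth_pct:.2f}% of RSS")
print(f"average per record:  {bytes_per_record:.2f} bytes")
```

That works out to roughly 4.5MB, about 0.27% of RSS, or around 5 bytes per record on average.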
@stas00 my 2 cents from having looked at a LOT of memory leaks over the years, esp in Python: with a .3% memory increase over that many iterations, it's difficult to say with certainty that it is a leak.
Also, just looking at RSS makes it hard to analyze leaks. RSS can stay near constant while you are leaking. RSS is paged in mem, if you have a big leak your RSS might not increase much (leaked mem tends not to get used again so often paged out) while your virtual page allocation could be going through the roof...
yes, that's true, but unless the leak is big, I'm yet to find another measurement tool.
To prove your point here is a very simple IO in a loop program that also reads the same line all over again:
$ cat mmap-no-leak-debug.py
import gc
import mmap
import os
import psutil
import sys
proc = psutil.Process(os.getpid())
PATH = "./tmp.txt"
def mem_read():
    gc.collect()
    return proc.memory_info().rss / 2**20
# create a large data file with a few long lines
if not os.path.exists(PATH):
    with open(PATH, "w") as fh:
        s = 'a'* 2**27 + "\n" # 128MB
        # write ~2GB file
        for i in range(16):
            fh.write(s)
print(f"{'idx':>4} {'RSS':>10} {'Δ RSS':>12} {'Δ accumulated':>10}")
total_read = 0
content = ''
mem_after = mem_before_acc = mem_after_acc = mem_before = proc.memory_info().rss / 2**20
print(f"{0:4d} {mem_after:10.2f}MB {mem_after - 0:10.2f}MB {0:10.2f}MB")
mmap_mode = True if "--mmap" in sys.argv else False
with open(PATH, "r") as fh:
    if mmap_mode:
        mm = mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ)
    idx = 0
    while True:
        idx += 1
        mem_before = mem_read()
        line = mm.readline() if mmap_mode else fh.readline()
        if not line:
            break
        #total_read += len(line)
        if "--accumulate" in sys.argv:
            mem_before_acc = mem_read()
            content += str(line)
            mem_after_acc = mem_read()
        mem_after = mem_read()
        print(f"{idx:4d} {mem_after:10.2f}MB {mem_after - mem_before:10.2f}MB {mem_after_acc - mem_before_acc:10.2f}MB")
it has some other instrumentations to do mmap and accumulate data, but let's ignore that for now.
Here it is running in a simple non-mmap IO:
$ python mmap-no-leak-debug.py
idx RSS Δ RSS Δ accumulated
0 12.43MB 12.43MB 0.00MB
1 269.72MB 257.29MB 0.00MB
2 269.73MB 0.02MB 0.00MB
3 269.73MB 0.00MB 0.00MB
4 269.74MB 0.01MB 0.00MB
5 269.74MB 0.00MB 0.00MB
6 269.75MB 0.01MB 0.00MB
7 269.75MB 0.00MB 0.00MB
8 269.76MB 0.01MB 0.00MB
9 269.76MB 0.00MB 0.00MB
10 269.77MB 0.01MB 0.00MB
11 269.77MB 0.00MB 0.00MB
12 269.77MB 0.00MB 0.00MB
13 269.77MB 0.00MB 0.00MB
14 269.77MB 0.00MB 0.00MB
15 269.77MB 0.00MB 0.00MB
16 146.02MB -123.75MB 0.00MB
As you can see, even this super-simplistic program that just performs readline() slightly increases in RSS over iterations.
If you have a better tool for measurement other than RSS, I'm all ears.
@stas00 if you aren't using memory maps, you should be able to clearly see the increase in the virtual mem for the process as well. Even then, it could still be challenging to determine if it's leak vs fragmentation due to problematic allocation patterns (not uncommon with Python). Using a better mem allocator like tcmalloc via LD_PRELOAD hooks could reduce impact of fragmentation across both Python and c libs. Not sure that plays nice with any allocator that arrow might use itself though.
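On Linux, one way to look at virtual size next to RSS, and to split RSS into anonymous and file-backed parts, is to read /proc/self/status. A minimal sketch, assuming a kernel recent enough (4.5+) to expose the RssAnon/RssFile/RssShmem fields; parse_status and read_self_status are hypothetical helper names, and SAMPLE holds made-up numbers so the parser can be exercised on any platform:

```python
def parse_status(text):
    """Return the Vm* and Rss* fields of a /proc/<pid>/status dump, in MB."""
    fields = {}
    for line in text.splitlines():
        key, _, rest = line.partition(":")
        if key.startswith(("Vm", "Rss")):
            parts = rest.split()
            if len(parts) == 2 and parts[1] == "kB":
                fields[key] = int(parts[0]) / 1024
    return fields


def read_self_status():
    """Read the current process' fields; empty dict where /proc is missing."""
    try:
        with open("/proc/self/status") as fh:
            return parse_status(fh.read())
    except OSError:
        return {}


# Made-up sample in the same layout as /proc/<pid>/status (not real data).
SAMPLE = (
    "Name:\tpython\n"
    "VmSize:\t  204800 kB\n"
    "VmRSS:\t  102400 kB\n"
    "RssAnon:\t   20480 kB\n"
    "RssFile:\t   81920 kB\n"
    "RssShmem:\t       0 kB\n"
)

sample = parse_status(SAMPLE)
print("sample:", {k: sample[k] for k in ("VmSize", "VmRSS", "RssAnon", "RssFile")})

live = read_self_status()
for key in ("VmSize", "VmRSS", "RssAnon", "RssFile", "RssShmem"):
    if key in live:
        print(f"{key:>8}: {live[key]:10.2f}MB")
```

VmRSS is the sum of RssAnon, RssFile and RssShmem. Memory-mapped file pages are counted under RssFile, so tracking RssAnon and VmSize over iterations may give a clearer picture than the total RSS alone.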
Thank you for these suggestions, Ross.
The problem is that most of the time we use a bunch of non-python libs that are bound to python, and so besides python one has to deal with the not-quite-controllable allocation strategies of those other components as well. So it's a super-tricky world.
Good suggestion on memory fragmentation, which could definitely be one of the sources for ever-growing RSS. pytorch's memory management utils are mostly quite excellent, and fragmentation is one of the main issues there. Projects like Deepspeed try to solve it by pre-allocating memory themselves and then managing it tightly to avoid fragmentation, which seems to work quite well.
BTW, I'm not sure if you have seen this tool I developed some years back to automatically track and report CPU and GPU memory usage in Jupyter notebooks. https://github.com/stas00/ipyexperiments I found it to be quite useful for detecting memory leakage - of course it's the same RSS for CPU, but it's just automated where each cell reports the delta. One other tricky thing to measure is CPU peak memory which it provides. As often there are those temp leaks which lead to OOMs.
OK, I ended up compiling the different stages of the research into an article, including a recommendation on how to keep datasets from interfering with memory leak debugging:
https://gist.github.com/stas00/035d823106a896b1ba2032d34b97d541
@lhoestq, please have a look at your convenience and let's discuss how we use that in the datasets docs. We can for example:
include it as a separate doc in the datasets docs.
publish it as an HF blog post, link to it from the datasets docs and perhaps only quote the last section that shows how to debug memory leaks while using datasets.
Of course I'm open to other options.
(I of course need to proof-read it, it surely could use an editing pass, I only checked the numbers made sense, but it should be quite readable already)
And I will paste the last section of the article here for posterity should the original disappear:
Therefore the easiest approach is to create a synthetic dataset of the desired length with all records being the same. That way the data is no longer a factor in the memory usage patterns as it's always the same.
$ cat ds-synthetic-no-mmap.py
from datasets import load_from_disk, Dataset
import gc
import sys
import os
import psutil
proc = psutil.Process(os.getpid())
def mem_read():
    gc.collect()
    return proc.memory_info().rss / 2**20
DS_PATH = "synthetic-ds"
if not os.path.exists(DS_PATH):
    records = 1_000_000
    print("Creating a synthetic dataset")
    row = dict(foo=[dict(a='a'*500, b='b'*1000)])
    ds = Dataset.from_dict({k: [v] * records for k, v in row.items()})
    ds.save_to_disk(DS_PATH)
    print("Done. Please restart the program")
    sys.exit()
dataset = load_from_disk(DS_PATH, keep_in_memory=True)
print(f"Dataset len={len(dataset)}")
print(f"{'idx':>8} {'RSS':>10} {'Δ RSS':>15}")
mem_start = 0
step = 50_000
warmup_iterations = 4
for idx, i in enumerate(range(0, len(dataset), step)):
    if idx == warmup_iterations: # skip the first few iterations while things get set up
        mem_start = mem_read()
    mem_before = mem_read()
    _ = dataset[i:i+step]
    mem_after = mem_read()
    print(f"{i:8d} {mem_after:12.4f}MB {mem_after - mem_before:12.4f}MB")
mem_end = mem_read()
print(f"Total diff: {mem_end - mem_start:12.4f}MB (after {warmup_iterations} warmup iterations)")
We run this program once to create the dataset, and then the second time to profile its memory usage:
$ python ds-synthetic-no-mmap.py
Creating a synthetic dataset
Done. Please restart the program
$ python ds-synthetic-no-mmap.py
Dataset len=1000000
idx RSS Δ RSS
0 1649.6055MB 95.1992MB
50000 1728.4961MB 78.8906MB
100000 1728.7109MB 0.2148MB
150000 1729.2539MB 0.5430MB
200000 1729.0039MB -0.2500MB
250000 1729.5039MB 0.5000MB
300000 1729.2539MB -0.2500MB
350000 1729.7539MB 0.5000MB
400000 1729.5039MB -0.2500MB
450000 1730.0039MB 0.5000MB
500000 1729.7539MB -0.2500MB
550000 1730.2539MB 0.5000MB
600000 1730.0039MB -0.2500MB
650000 1730.5039MB 0.5000MB
700000 1730.2539MB -0.2500MB
750000 1730.7539MB 0.5000MB
800000 1730.5039MB -0.2500MB
850000 1731.0039MB 0.5000MB
900000 1730.7539MB -0.2500MB
950000 1731.2539MB 0.5000MB
Total diff: 2.0000MB (after 4 warmup iterations)
This is much better. There are still tiny fluctuations due to Python, and as you can see in the code I skipped the first few iterations while things are being set up.
But otherwise now you can easily debug the rest of your code for any memory leaks, since datasets is in non-MMAP mode and the record size doesn't fluctuate.
Of course, do not forget to flip load_from_disk(..., keep_in_memory=True) to False when the debug is over so that you get back the performance speed up provided by MMAP.
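If you want to reduce a log like this to a single number, the warmup handling can be sketched as a small pure function. The names `rss_growth_after_warmup` and `looks_like_leak` and the 5MB tolerance are hypothetical choices for illustration, not part of datasets or psutil; the readings are an abbreviated copy of the RSS column from the run shown earlier.

```python
def rss_growth_after_warmup(readings_mb, warmup_iterations):
    """Growth in MB from the reading taken after the last warmup iteration to the final one."""
    if warmup_iterations < 1 or len(readings_mb) <= warmup_iterations:
        raise ValueError("need at least one warmup reading and one reading after it")
    baseline = readings_mb[warmup_iterations - 1]
    return readings_mb[-1] - baseline


def looks_like_leak(readings_mb, warmup_iterations, tolerance_mb=5.0):
    """Flag a run whose post-warmup growth exceeds an (arbitrary) tolerance."""
    return rss_growth_after_warmup(readings_mb, warmup_iterations) > tolerance_mb


# abbreviated RSS column (MB): first five readings and the last one
readings = [1649.6055, 1728.4961, 1728.7109, 1729.2539, 1729.0039, 1731.2539]
growth = rss_growth_after_warmup(readings, warmup_iterations=4)
print(f"growth after warmup: {growth:.4f}MB, leak suspected: {looks_like_leak(readings, 4)}")
```

A sensible tolerance depends on the workload, so treat the default as a placeholder to tune rather than a recommendation.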
Thanks ! Before thinking about the documentation, I'd like to make sure we can explain what happens in a correct and simple manner.
When accessing memory mapped data, the corresponding pages are loaded one at a time in the main memory and increase RSS.
When the pages are not used anymore, they are paged out, otherwise they would fill up the physical RAM. However, RSS doesn't decrease. If you read the entire file, then the RSS will end up bigger than the file, no matter how much physical RAM you have.
Is this correct ?
All but the last sentence I think.
It won't be able to read into memory more data than there is actual free host memory, as those are real paged-in data. If there is no free memory, it can't read it in.
The cgroups-limited shell just shows that the paged-in cached MMAPed data doesn't count towards the process' actually used memory.
In other words if your host has 8GB of RAM and 0 SWAP, and all of that RAM is free and your process is unlimited it'll only be able to page in 8 GB of MMAPed data.
It's challenging to describe what happens in a simple manner because it's not so simple. It's system level behaviour that depends on the system wide configuration and all of the processes running. The OS virtual memory manager needs to manage the pages across all of the processes, incl the mmap'd files. Pages can stay in a while if there is little demand, or get paged out quickly if there is high demand (from possibly any of the other processes on the system). Random access patterns (ie need one entry here, one there) will require a new page in for each access, or multiple pages if there is seq read-ahead enabled. If you have little system memory and are mapping large files, jumping around, you can easily end up thrashing (constantly paging in, and then ejecting, which results in disk IO for mmap'd files). So what you observe for a single process, what's resident, can change based on what other processes are doing as well...
Many databases and applications that do heavy IO on large files opt to handle IO and buffer management themselves vs leave it up to the OS VMM. While mmap'd files are usually good enough, there are some pretty confusing edge case behaviours. Just loading the buffers directly into memory will be more consistent behaviour wise. However, that is only practical if the maximum file size is limited to a reasonable value (ie data is sharded into 256MB-2GB chunks).
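To make the access-pattern point concrete, here is a minimal sketch using Python's standard `mmap` module. It only shows that each access touches the page holding the requested byte; it does not measure residency, since how long a page stays in RAM is up to the OS. `touch_pages` is a hypothetical helper and the 8-page file is an arbitrary size chosen for illustration.

```python
import mmap
import os
import tempfile

PAGE = mmap.PAGESIZE


def touch_pages(path, page_indices):
    """Memory-map `path` read-only and read the first byte of each requested page."""
    with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        return bytes(mm[i * PAGE] for i in page_indices)


# build a small file in which every page starts with a different letter
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    for n in range(8):
        tmp.write(bytes([65 + n]) + b"\0" * (PAGE - 1))
    path = tmp.name

sequential = touch_pages(path, range(8))     # benefits from sequential read-ahead
scattered = touch_pages(path, [7, 0, 5, 2])  # each access may need its own page-in
os.remove(path)
print(sequential, scattered)
```

With a file much larger than free RAM, the scattered pattern is the one that can lead to the thrashing described in this comment.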
Indeed, this is all very complex and there is no deterministic way of tracking real memory usage of a process as a whole, especially so when different unrelated components are intermixed - e.g. python and 3rd party c++ libraries bound to it. Each component's memory allocations can be traced separately, e.g. tracemalloc for python, torch.cuda.memory* for torch, mem_pool.bytes_allocated for pyarrow, etc.
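As an illustration of per-component tracing, here is a minimal `tracemalloc` sketch. It only sees memory allocated through Python's allocator, so anything allocated inside native libraries will not show up here. The 10,000 retained objects are an arbitrary stand-in for a real leak.

```python
import tracemalloc

tracemalloc.start()
before = tracemalloc.take_snapshot()
retained = [bytes(1000) for _ in range(10_000)]  # roughly 10MB that stays referenced
after = tracemalloc.take_snapshot()
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

# the first entry is the source line with the largest change between the two snapshots
top = after.compare_to(before, "lineno")[0]
frame = top.traceback[0]
print(f"current={current / 2**20:.1f}MB peak={peak / 2**20:.1f}MB")
print(f"largest growth: {top.size_diff / 2**20:.1f}MB at {frame.filename}:{frame.lineno}")
```

Comparing two snapshots taken a few iterations apart is usually more informative than a single reading, because it points at the line whose allocations keep growing.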
Instead of trying to explain all these complexities, I propose that all the lay user needs to know is this:
In my write up I have crafted a few set ups using code samples and cgroups-limited shells that demonstrate the above. So the above summary can be TLDR and the article can be the long story for those who care for an empirical demonstration.
If you have about 10min please kindly skim through my write up and let me know if you found some holes / unclear parts or whether this clears things up after reading it. Thank you!
I asked around and it seemed that my write up won't fit into the HF blog, so I published it on my blog https://stasosphere.com/entrepreneur-being/301-mmap-memory-leak-investigation/ I added context and edited it some more.
@lhoestq, I am not sure if you still want to include any of my notes in the datasets docs or not. It's totally your call.
Ok :) I think the explanations in https://github.com/huggingface/datasets/issues/4883#issuecomment-1252722977 should be helpful for many users already. And maybe redirect to your blog for the details ? Then to debug how much Arrow data is physically in RAM we can also advise pa.total_allocated_bytes. What do you think ?
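A minimal sketch of that check, assuming pyarrow is installed (the helper returns None otherwise). `arrow_allocated_mb` is a hypothetical name; `pa.total_allocated_bytes()` reports bytes allocated through Arrow's default memory pool, so memory-mapped data is not expected to be counted in it.

```python
def arrow_allocated_mb():
    """Return MB currently allocated by Arrow's default memory pool, or None without pyarrow."""
    try:
        import pyarrow as pa
    except ImportError:
        return None
    return pa.total_allocated_bytes() / 2**20


print(f"Arrow allocated: {arrow_allocated_mb()}")
```

Calling it before and after loading a dataset gives the delta attributable to Arrow allocations, which can then be compared with the RSS delta.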
That sounds like a good plan, @lhoestq!
Wanted to add one method for dealing with issues in datasets/pyarrow that seem related to this thread (and I hope it helps someone; I imagine they'll land here like I did). I was encoding image data on a fairly large dataset using datasets .map() in a vanilla way and running into mysterious pyarrow errors, or no errors at all and simply hanging forever. When run locally on one processor, I was getting the same results as this thread (memory seemed to increase forever). I tried every solution I could find with no luck.
I came up with a simple way to reduce the pressure on datasets/pyarrow during mapping/index flattening while still getting some of the benefit of parallel processing; this method was able to encode and save a 229GB dataset on a system with much less memory than that.
In my case the dataset starts as rows with file paths; during the .map() phase, for each row (or training instance) the called function can open these file paths to perform encoding. Because of this, I can know the length of the dataset in advance.
- Using the dataset length and passed mapper batch size, find a divisor of the dataset length that divides it evenly (in my case, if it happens to be prime, just subtract one training example off the end).
- Using this divisor, iterate the dataset in chunks of this divisor size (call it a 'step' size).
- The actual .map() batch_size/writer_batch_size can then be the step size divided by the number of processors with some sane minimum. E.g. if you are dropping bad instances using batched=True you need a sane minimum so a given processor doesn't result in 0 instances. So basically you are mapping on a chunk of the dataset and dividing that chunk among the processors.
- Start iterating on the dataset given the step size; pass that chunk to .map() with a batch_size/writer_batch_size that divides it among the processors.
- When a given chunk completes the .map() phase, call save_to_disk on it so it doesn't stay in memory or cause any weirdness with memory mapping.
- Once all chunks have been encoded, iterate the directory where they were saved and call load_from_disk() on them (with keep_in_memory=False of course) and append them to a list.
- Call concatenate_datasets() on that list.
This method is much slower than vanilla mapping, but enabled me to get past all the weird pyarrow issues and successfully encode a dataset that was much larger than the available memory on the instance it was running on.
plz give us the python code
I don't have time to generalize or really explain each step here, but maybe this will help someone. I recently used this technique on a dataset with ~5,000,000 sentences when straightforward dataset.map() failed in a way I've become familiar with. Note that the chosen config.batch_size will have a huge impact on performance since each chunk is getting saved to disk. You want to pick one that's basically as large as possible knowing it's split among x processors, but not so large it can hit the memory weirdness with straightforward dataset.map()'ing. E.g. on that 5 mil sentence dataset, a batch size of 50,000 worked very well.
import os
from datasets import concatenate_datasets, load_from_disk

dataset_len = len(dataset)
step = None
# find the smallest divisor of dataset_len that is at least config.batch_size
for num in range(config.batch_size, dataset_len):
    if dataset_len % num == 0:
        step = num
        break
if not step:
    # no divisor found (e.g. prime length): drop the last example and retry
    dataset = dataset.select(range(0, dataset_len - 1))
    dataset_len = dataset_len - 1
    for num in range(config.batch_size, dataset_len):
        if dataset_len % num == 0:
            step = num
            break
if not step:
    raise Exception(f"Could not find a step size for {dataset_len}")
# split each chunk among the processes, with a floor of 200 rows per batch
batch_size = max(step // num_processes, 200)
print(
    f'Initiating encoding via concatenating datasets. num_processes={num_processes}, keep_in_memory={config.keep_in_memory}, dataset_len={dataset_len}, step={step}, batch_size={batch_size}')
for idx in range(0, dataset_len, step):
    print(f"Operating on {idx} -> {idx+step} of {dataset_len}")
    chunk = dataset.select(range(idx, idx+step))
    encoded_chunk = chunk.map(
        preprocess_data,
        fn_kwargs=dict(tokenizer=tokenizer, labels=labels, label2id=label2id, config=config,
                       batch_size=batch_size),
        keep_in_memory=config.keep_in_memory,
        batched=True,
        features=features,
        remove_columns=dataset.column_names,
        num_proc=num_processes,
        batch_size=batch_size,
        writer_batch_size=batch_size
    )
    print(f"Saving {idx} -> {idx+step} chunk to disk")
    encoded_chunk.save_to_disk(f"{SAGEMAKER_CHUNK_DIR}/chunk_{idx}_{idx+step}")
print("Loading encoded chunks for concatenating")
encoded_chunks = []
parent_path, subdirs, parent_files = next(os.walk(SAGEMAKER_CHUNK_DIR))
for chunk in subdirs:
    print(f"Loading chunk: {chunk}")
    curr_path = os.path.join(parent_path, chunk)
    encoded_chunks.append(load_from_disk(curr_path))
print(f"Concatenating {len(encoded_chunks)} datasets")
encoded_dataset = concatenate_datasets(encoded_chunks)
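The step search above can be isolated into a small function that is easy to test without a dataset. The sketch also adds one thing that is my assumption rather than part of the original code: a sort key for the chunk directories, since `os.walk` does not promise any particular order and a plain string sort would put `chunk_1000_1100` before `chunk_100_200`. This only matters if the row order of the concatenated dataset matters to you. `find_step` and `chunk_start` are hypothetical helper names.

```python
import re


def find_step(dataset_len, min_step):
    """Return the smallest divisor d of dataset_len with min_step <= d < dataset_len, else None."""
    for num in range(min_step, dataset_len):
        if dataset_len % num == 0:
            return num
    return None


def chunk_start(dirname):
    """Extract the start index from a directory name such as 'chunk_100_200'."""
    match = re.fullmatch(r"chunk_(\d+)_(\d+)", dirname)
    if match is None:
        raise ValueError(f"unexpected chunk directory name: {dirname}")
    return int(match.group(1))


print(find_step(5_000_000, 50_000))  # divides evenly, so the batch size itself is used
print(find_step(101, 10))            # prime length: no step, drop one example and retry
print(sorted(["chunk_100_200", "chunk_0_100", "chunk_1000_1100"], key=chunk_start))
```

In the loading loop this would amount to iterating over `sorted(subdirs, key=chunk_start)` instead of `subdirs`.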
I ended up splitting my DatasetDicts into N pieces
split_data = False
split_datadict = True
N = 50
total_rows = len(gtzan)
split_size = total_rows // N
remaining = total_rows % N  # To handle any remainder
from datasets import DatasetDict
if split_datadict:
    # Initialize a list to hold the new DatasetDicts
    split_dataset_dicts = [DatasetDict() for _ in range(N)]
    # Process each dataset in the DatasetDict
    for dataset_name, dataset in gtzan.items():
        total_rows = len(dataset)
        split_size = total_rows // N
        remaining = total_rows % N
        # Create the splits for this dataset
        start_idx = 0
        for i in range(N):
            end_idx = start_idx + split_size + (1 if i < remaining else 0)
            split = dataset.select(range(start_idx, end_idx))
            # Add this split to the corresponding new DatasetDict
            split_dataset_dicts[i][dataset_name] = split
            start_idx = end_idx
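The index arithmetic in that loop can be checked on its own, without any dataset. A minimal sketch, with `split_bounds` as a hypothetical helper name: the first `total_rows % n` pieces get one extra row, so the pieces cover every row exactly once.

```python
def split_bounds(total_rows, n):
    """Return n (start, end) pairs covering range(total_rows), sizes differing by at most 1."""
    split_size, remaining = divmod(total_rows, n)
    bounds = []
    start = 0
    for i in range(n):
        # the first `remaining` pieces absorb the leftover rows, one each
        end = start + split_size + (1 if i < remaining else 0)
        bounds.append((start, end))
        start = end
    return bounds


print(split_bounds(10, 3))
```

Each pair can be passed straight to `dataset.select(range(start, end))`, as in the loop above.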
then run my function
# Initialize an empty list for the processed splits
processed_splits = []
# Loop through each split and apply the preprocessing
for cnt, split in enumerate(split_dataset_dicts):
    print(f"iteration:{cnt}")
    processed_split = split.map(preprocess_function,
                                remove_columns=["audio"],
                                batched=True,
                                num_proc=4,
                                )
    processed_splits.append(processed_split)
Interestingly enough, a list comprehension does not work here and hangs too, so go with a good old-fashioned loop.
then concat the result
from datasets import concatenate_datasets, DatasetDict
# Assuming processed_splits is the list of processed DatasetDicts from the loop above
# Initialize a dictionary to hold the concatenated datasets
concatenated_datasets = {}
# Get all dataset names (keys) from the first split (assuming all splits have the same structure)
dataset_names = processed_splits[0].keys()
# Concatenate the datasets for each name
for name in dataset_names:
    # Collect the same part of each processed split
    datasets_to_concatenate = [split[name] for split in processed_splits]
    # Concatenate them
    concatenated_dataset = concatenate_datasets(datasets_to_concatenate)
    # Add to the dictionary
    concatenated_datasets[name] = concatenated_dataset
# Create a new DatasetDict with the concatenated datasets
concatenated_dataset_dict = DatasetDict(concatenated_datasets)
Describe the bug
When the HF datasets is used in conjunction with PyTorch Dataloader, the RSS memory of the process keeps on increasing when it should stay constant.
Steps to reproduce the bug
Run and observe the output of this snippet which logs RSS memory.
Expected results
Memory should not increase after initial setup and loading of the dataset
Actual results
Memory continuously increases as can be seen in the log.
Environment info
datasets version: 2.3.2