huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools
https://huggingface.co/docs/datasets
Apache License 2.0

With dataloader RSS memory consumed by HF datasets monotonically increases #4883

Open apsdehal opened 2 years ago

apsdehal commented 2 years ago

Describe the bug

When HF datasets is used together with a PyTorch DataLoader, the RSS memory of the process keeps increasing when it should stay constant.

Steps to reproduce the bug

Run and observe the output of this snippet which logs RSS memory.

import psutil
import os
from transformers import BertTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader

BATCH_SIZE = 32
NUM_TRIES = 10

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def transform(x):
    x.update(tokenizer(x["text"], return_tensors="pt", max_length=64, padding="max_length", truncation=True))
    x.pop("text")
    x.pop("label")
    return x
dataset = load_dataset("imdb", split="train")
dataset.set_transform(transform)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
count = 0
while count < NUM_TRIES:
    for idx, batch in enumerate(train_loader):
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(count, idx, mem_after - mem_before)
    count += 1

Expected results

Memory should not increase after initial setup and loading of the dataset

Actual results

Memory continuously increases as can be seen in the log.

Environment info

stas00 commented 2 years ago

Are you sure there is a leak? How can I see it? You shared the script but not the output which you believe should indicate a leak.

I modified your reproduction script to print only once per try, as your original was printing too much info. You absolutely must call gc.collect() when doing any memory measurements, since Python's GC runs on its own schedule and you might otherwise be measuring the wrong thing. This gives us:

import psutil
import os
import gc
from transformers import BertTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader

BATCH_SIZE = 32
NUM_TRIES = 100

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def transform(x):
    x.update(tokenizer(x["text"], return_tensors="pt", max_length=64, padding="max_length", truncation=True))
    x.pop("text")
    x.pop("label")
    return x
dataset = load_dataset("imdb", split="train")
dataset.set_transform(transform)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)

count = 0
while count < NUM_TRIES:
    for idx, batch in enumerate(train_loader): pass
    gc.collect()
    mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
    print(count, mem_after - mem_before)
    count += 1

Now running it:

$ python dl-leak.py 
Reusing dataset imdb (/home/stas/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)
0 4.43359375
1 4.4453125
2 4.44921875
3 4.44921875
4 4.4609375
5 4.46484375
6 4.46484375
7 4.46484375
8 4.46484375
9 4.46484375
10 4.46484375
11 4.46484375
12 4.46484375
13 4.46484375
14 4.46484375
15 4.46484375
16 4.46484375

It's normal that at the beginning there is a small growth in memory usage, but after 5 cycles it gets steady.
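The "steady after a few cycles" judgement can also be made mechanically. Below is a minimal sketch, assuming a hypothetical helper name (`first_steady_index`) and an arbitrarily chosen tolerance and window; the numbers are the per-epoch deltas from the run above.

```python
def first_steady_index(deltas_mb, tolerance_mb=0.001, window=5):
    """Return the first index from which `window` consecutive readings
    stay within `tolerance_mb` of each other, or None if there is none."""
    for start in range(len(deltas_mb) - window + 1):
        chunk = deltas_mb[start:start + window]
        if max(chunk) - min(chunk) <= tolerance_mb:
            return start
    return None

# Per-epoch RSS deltas (MB) copied from the first ten lines of the log above.
deltas = [4.43359375, 4.4453125, 4.44921875, 4.44921875, 4.4609375,
          4.46484375, 4.46484375, 4.46484375, 4.46484375, 4.46484375]

print(first_steady_index(deltas))
```

A real leak would never produce a window like this, so the helper would return None no matter how many epochs are logged.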

stas00 commented 2 years ago

Unless of course you're referring to the memory growth during the first try. Is that what you're referring to? And since your ds is small it's hard to see the growth - could it be just because some records are longer and it needs to allocate more memory for those?

Though while experimenting with this I observed a peculiar thing: if I concatenate 2 datasets, I don't see any growth at all. But that's probably because the program allocated additional peak RSS memory to concatenate and is then re-using that memory.

I basically tried to see what happens if I make the dataset much longer; I'd expect not to see any memory growth once the 780 records of the imdb ds have been processed once.

apsdehal commented 2 years ago

It is hard to say if it is directly reproducible in this setup. Maybe it is specific to the images stored in the CM4 case which cause a memory leak. I am still running your script and seeing if I can reproduce that particular leak in this case.

stas00 commented 2 years ago

I was able to reproduce the leak with:

import psutil
import os
import gc
from datasets import load_from_disk
import time

DATASET_PATH = "/hf/m4-master/data/cm4/cm4-10000-v0.1"

dataset = load_from_disk(DATASET_PATH)

# truncate to a tiny dataset
dataset = dataset.select(range(1000))

print(f"dataset: {len(dataset)} records")

mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
for idx, rec in enumerate(dataset):
    if idx % 100 == 0:
        gc.collect()
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")

You need to adjust DATASET_PATH to point to the dataset, which you get from:

gsutil -m cp   "gs://hf-science-m4/cm4/cm4-10000-v0.1/dataset.arrow"   "gs://hf-science-m4/cm4/cm4-10000-v0.1/dataset_info.json"   "gs://hf-science-m4/cm4/cm4-10000-v0.1/state.json"   .

(I assume the hf folks have the perms) - it's a smallish dataset (10k)

then you run:

$ python ds.py
dataset: 1000 records
   0       1.0156MB
 100     126.3906MB
 200     142.8906MB
 300     168.5586MB
 400     218.3867MB
 500     230.7070MB
 600     238.9570MB
 700     263.3789MB
 800     288.1289MB
 900     300.5039MB

you should be able to see the leak

stas00 commented 2 years ago

This issue has nothing to do with PIL's decoder. I removed it and the problem is still there.

I then traced this leak to this single call: pa_table.to_pydict() here:

https://github.com/huggingface/datasets/blob/08a7b389cdd6fb49264a72aa8ccfc49a233494b6/src/datasets/formatting/formatting.py#L138-L140

I can make it leak much faster by modifying that code to repeat pa_table.to_pydict() many times in a row. It shouldn't have that impact:

class PythonArrowExtractor(BaseArrowExtractor[dict, list, dict]):
    def extract_row(self, pa_table: pa.Table) -> dict:
        x = [pa_table.to_pydict() for x in range(200)]
        return _unnest(pa_table.to_pydict())

@lhoestq - do you know what might be happening inside pa_table.to_pydict(), as this is in the pyarrow domain. Perhaps you know someone to tag from that project?

Probably next need to remove datasets from the equation and make a reproducible case with just pyarrow directly.

The problem already happens with pyarrow==6.0.0 or later (minimum for current datasets)

I'm also trying to dig in with objgraph to see if there are any circular references which prevent objects from being freed, but no luck there so far. And I'm pretty sure to_pydict is not Python code, so the problem is likely to happen somewhere outside of Python's GC.
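For reference, this is what a pure-Python circular reference looks like and why gc.collect() matters for it. It is only a sketch with a hypothetical `Node` class, not code from datasets or pyarrow, and it says nothing about memory held on the C++ side.

```python
import gc
import weakref

class Node:
    """Tiny object that can point at another Node."""
    def __init__(self):
        self.other = None

def make_cycle():
    a, b = Node(), Node()
    a.other, b.other = b, a      # a -> b -> a: a reference cycle
    return weakref.ref(a)        # a weak reference does not keep `a` alive

gc.collect()                     # start from a clean slate
gc.disable()                     # no automatic collection during the demo
probe = make_cycle()
alive_before = probe() is not None   # the cycle keeps both nodes alive
unreachable = gc.collect()           # the cycle collector finds and frees them
alive_after = probe() is not None
gc.enable()

print(alive_before, unreachable, alive_after)
```

Reference counting alone never frees the two nodes; only the cycle collector does, which is why measurements taken without gc.collect() can look like a leak.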

stas00 commented 2 years ago

I think this is the same issue: https://github.com/huggingface/datasets/issues/4528. I dug into the repro code there and it's the same behavior with the same leak, but it's a pure nlp dataset and thus much faster to work with.

stas00 commented 2 years ago

I went all the way back to pyarrow==1.0.0 and datasets==1.12.0 and the problem is still there. How is it even possible that it wasn't noticed all this time?

Could it be that the leak is in some 3rd party component pyarrow relies on? While downgrading I only downgraded the above 2 packages.

stas00 commented 2 years ago

Also found this warning

    Be careful: if you don't pass the ArrowArray struct to a consumer,
    array memory will leak.  This is a low-level function intended for
    expert users.

see: https://github.com/apache/arrow/blob/99b57e84277f24e8ec1ddadbb11ef8b4f43c8c89/python/pyarrow/table.pxi#L2515-L2517

perhaps something triggers this condition?

I have no idea if it's related - this is just something that came up during my research.

rwightman commented 2 years ago

Does it crash with OOM at some point? If it doesn't, it isn't a leak, just aggressive caching or a custom allocator that doesn't like to give memory back (not uncommon). #4528 looks like it hits a steady state.

I believe the underlying arrow libs use a custom C allocator. Some of those are designed not to give back to OS, but keep heap memory for themselves to re-use (hitting up the OS involves more expensive mutex locks, contention, etc). The greedy behaviour can be undesirable though. There are likely flags to change the allocator behaviour, and one could likely build without any custom allocators (or use a different one).

SaulLu commented 2 years ago

Does it crash with OOM at some point?

In the original setup where we noticed this problem, it was indeed ending in an OOM

NouamaneTazi commented 2 years ago

https://github.com/huggingface/datasets/issues/4528 looks like it hits a steady state.

@rwightman in the plot I shared, the steady state comes from the time.sleep(100) I added at the end of the script, to showcase that even the garbage collector couldn't free that allocated memory.

SaulLu commented 2 years ago

Could this be related to this discussion about a potential memory leak in pyarrow: https://issues.apache.org/jira/browse/ARROW-11007 ?

(Note: I've tried import pyarrow; pyarrow.jemalloc_set_decay_ms(0) and the memory leak is still happening on your toy example)

lhoestq commented 2 years ago

@lhoestq - do you know what might be happening inside pa_table.to_pydict(), as this is in the pyarrow domain. Perhaps you know someone to tag from that project?

to_pydict calls to_pylist on each column (i.e. on each PyArrow Array). Then it iterates on the array and calls as_py on each element. The as_py implementation depends on the data type. For strings I think it simply gets the buffer that contains the binary string data that is defined in C++

The Arrow team is pretty responsive at user@arrow.apache.org if it can help

Probably next need to remove datasets from the equation and make a reproducible case with just pyarrow directly.

That would be ideal indeed. Would be happy to help on this, can you give me access to the bucket so I can try with your data ?

VictorSanh commented 2 years ago

That would be ideal indeed. Would be happy to help on this, can you give me access to the bucket so I can try with your data ?

I added you to the bucket @lhoestq

lhoestq commented 2 years ago

It looks like an issue with memory mapping:

lhoestq commented 2 years ago

Here is code to reproduce this issue using only PyArrow and a dummy arrow file:

import psutil
import os
import gc
import pyarrow as pa
import time

ARROW_PATH = "tmp.arrow"

if not os.path.exists(ARROW_PATH):
    arr = pa.array([b"a" * (200 * 1024)] * 1000)  # ~200MB
    table = pa.table({"a": arr})

    with open(ARROW_PATH, "wb") as f:
        writer = pa.RecordBatchStreamWriter(f, schema=table.schema)
        writer.write_table(table)
        writer.close()

def memory_mapped_arrow_table_from_file(filename: str) -> pa.Table:
    memory_mapped_stream = pa.memory_map(filename)
    opened_stream = pa.ipc.open_stream(memory_mapped_stream)
    pa_table = opened_stream.read_all()
    return pa_table

table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr = table[0]

mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
for idx, x in enumerate(arr):
    if idx % 100 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")

prints

   0       0.2500MB
 100      19.8008MB
 200      39.3320MB
 300      58.8633MB
 400      78.3945MB
 500      97.9258MB
 600     117.4570MB
 700     136.9883MB
 800     156.5195MB
 900     176.0508MB

Note that this example simply iterates over the pyarrow.lib.BinaryScalar objects in the array. Running .as_py() is not needed to experience the memory issue.

rwightman commented 2 years ago

@lhoestq that does indeed increase in memory, but if you iterate over the array again after the first time, or re-open and remap the same file (repeat table = memory_mapped_arrow_table_from_file(ARROW_PATH)) before re-iterating, it doesn't move past 195MB... it would appear another step is needed to continue consuming memory past that... hmmm

Are the pa_tables held on to anywhere after they are iterated in the real code?

In my hack, if you do a bunch of cut & paste and then change the arr name for each iteration:

table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr = table[0]

for idx, x in enumerate(arr):
    if idx % 100 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")

table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr1 = table[0]

for idx, x in enumerate(arr1):
    if idx % 100 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")

table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr2 = table[0]

for idx, x in enumerate(arr2):
    if idx % 100 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")

it leaks. If all the arr variables have the same name (so the previous one gets cleaned up) it does not, and goes back to 0. Is there anything that could be holding onto a reference to an intermediary equivalent of arr in the real use case?

stas00 commented 2 years ago

Yes, we have already established here https://github.com/huggingface/datasets/issues/4883#issuecomment-1232063891 that when one iterates over the whole dataset multiple times, it consumes a bit more memory in the next few repetitions and then remains steady.

Which means that when a new iterator is created over the same dataset, all the memory from the previous iterator is re-used.

So the leak happens primarily when the iterator is "drained" the first time, which tells me that either a circular reference is created somewhere which only gets released when the iterator is destroyed, or there is some global variable that keeps piling up the memory and doesn't release it in time.

Also I noticed some __del__ methods which won't destroy objects automatically and there is usually a warning against using it https://stackoverflow.com/a/1481512/9201239

There are also some weakrefs in the code which too may lead to leaks or weird problems at times.

rwightman commented 2 years ago

@stas00 my point was, I'm not convinced @lhoestq's last example illustrates the leak, but rather the differences between memory-mapped and in-memory usage patterns. If you destroy arr, the memory map impl goes back to 0 each iteration. The amount of memory that 'looks' like it is leaked in the first pass differs quite a bit between memory mapped vs in memory, but the underlying issue is likely a circular reference, or reference(s) which were not cleaned up, that would impact either case but is likely much more visible with mmap.

stas00 commented 2 years ago

Thank you for clarifying, Ross.

I think we agree that it's almost certain that the datasets iterator traps some inner variable that prevents object freeing, since if we create the iterator multiple times (and drain it) after a few runs no new memory is allocated. We could try to dig in more with objgraph - my main concern is if the problem happens somewhere outside of python, (i.e. in pyarrow cpp implementation) in which case it'd be much more difficult to trace.

I wish there was a way on linux to tell the program to free no longer used memory at will.
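On glibc-based Linux there is a partial answer: malloc_trim(0) asks the C allocator to hand free heap pages back to the OS. The sketch below calls it through ctypes; the helper name `trim_heap` is hypothetical, and the call only affects glibc's own malloc heap. It should not be expected to touch pages of a memory-mapped file or memory held by jemalloc/mimalloc pools, so whether it changes anything for the numbers in this thread is an open assumption.

```python
import ctypes
import ctypes.util

def trim_heap():
    """Ask glibc to return free heap pages to the OS.

    Returns True if memory was released, False if not, and None when
    malloc_trim is unavailable (for example on non-glibc platforms).
    """
    libname = ctypes.util.find_library("c")
    if libname is None:
        return None
    try:
        libc = ctypes.CDLL(libname)
        malloc_trim = libc.malloc_trim
    except (OSError, AttributeError):
        return None
    malloc_trim.argtypes = [ctypes.c_size_t]
    malloc_trim.restype = ctypes.c_int
    return bool(malloc_trim(0))

print(trim_heap())
```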

rwightman commented 2 years ago

FWIW, I revisited some code I had in the works to use HF datasets w/ timm train & val scripts. There is no leak there across multiple epochs. It uses the defaults.

It's worth noting that with imagenet keep_in_memory=True isn't even an option because the train arrow file is ~140GB and my local memory is less. The virtual address space reflects mmap (> 150GB) and doesn't increase over epochs that I noticed. I have some perf issues to bring up wrt to the current setup, but that's a separate and lower prio discussion to have elsewhere...

stas00 commented 2 years ago

Notes

After reading many issues and trying many things here is the summary of my learning

I'm now using @lhoestq repro case as it's pyarrow-isolated: https://github.com/huggingface/datasets/issues/4883#issuecomment-1242034985

1. pyarrow memory backends

it has 3 backends, I tried them all with the same results

pa.set_memory_pool(pa.jemalloc_memory_pool())
pa.set_memory_pool(pa.mimalloc_memory_pool())
pa.set_memory_pool(pa.system_memory_pool())

2. quick release

The jemalloc backend supports quick release

pa.jemalloc_set_decay_ms(0)

it doesn't make any difference in this case

3. actual memory allocations

this is a useful tracer for PA memory allocators

pa.log_memory_allocations(enable=True)

it nicely reports memory allocations and releases when the arrow file is created the first time.

but when we then try to do enumerate(arr) this logger reports 0 allocations.

This summary also reports no allocations when the script is run the second time (post file creation):

mem_pool = pa.default_memory_pool()
print(f"PyArrow mem pool info: {mem_pool.backend_name} backend, {mem_pool.bytes_allocated()} allocated, "
              f"{mem_pool.max_memory()} max allocated, ")

print(f"PyArrow total allocated bytes: {pa.total_allocated_bytes()}")

However, by using tracemalloc (which only measures Python's memory allocations) it's easy to see that it's PA that leaks, since tracemalloc shows fixed memory.

(this is bolted on top of the original repro script)

import tracemalloc
tracemalloc.start()

[...]
for idx, x in enumerate(arr):
    if idx % 10 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / 2**20
        mem_use = pa.total_allocated_bytes() - start_use
        mem_peak = pool.max_memory() - start_peak_use

        second_size, second_peak = tracemalloc.get_traced_memory()
        mem_diff = (second_size - first_size) / 2**20
        mem_peak_diff = (second_peak - first_peak) / 2**20

        # pa.jemalloc_memory_pool().release_unused()
        # pa.mimalloc_memory_pool().release_unused()
        # pa.system_memory_pool().release_unused()

        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB {mem_diff:12.4f} {mem_peak_diff:12.4f} {memory_mapped_stream.size()/2**20:4.4}MB {mem_use/2**20:4.4}MB {mem_peak/2**20:4.4}MB")

gives:

   0       5.4258MB       0.0110       0.0201 195.3MB  0.0MB  0.0MB
  10      25.3672MB       0.0112       0.0202 195.3MB  0.0MB  0.0MB
  20      45.9336MB       0.0112       0.0203 195.3MB  0.0MB  0.0MB
  30      62.4336MB       0.0112       0.0203 195.3MB  0.0MB  0.0MB
  40      83.0586MB       0.0112       0.0203 195.3MB  0.0MB  0.0MB
  50     103.6836MB       0.0112       0.0203 195.3MB  0.0MB  0.0MB
  60     124.3086MB       0.0112       0.0203 195.3MB  0.0MB  0.0MB
  70     140.8086MB       0.0112       0.0203 195.3MB  0.0MB  0.0MB
  80     161.4336MB       0.0112       0.0203 195.3MB  0.0MB  0.0MB
  90     182.0586MB       0.0112       0.0203 195.3MB  0.0MB  0.0MB

the 3rd and 4th columns are tracemalloc's report.

the 5th column is the size of mmaped stream - fixed.

the last 2 are the PA's malloc reports - you can see it's totally fixed and 0.

So what gives? PA's memory allocator says nothing was allocated and we can see python doesn't allocate any memory either.
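One thing worth keeping in mind when reading these columns: tracemalloc only tracks allocations made through Python's allocators, so pages that the kernel maps in for an mmap-ed file never show up in it. Here is a minimal stdlib-only sketch of that, using a scratch file as a hypothetical stand-in for the Arrow file (the 8 MiB size is arbitrary):

```python
import mmap
import os
import tempfile
import tracemalloc

SIZE = 8 * 2**20  # 8 MiB scratch file, size chosen arbitrarily

# Create a scratch file to map (stand-in for an Arrow file).
fd, path = tempfile.mkstemp()
try:
    with os.fdopen(fd, "wb") as fh:
        fh.write(b"a" * SIZE)

    tracemalloc.start()
    before, _ = tracemalloc.get_traced_memory()

    with open(path, "rb") as fh:
        with mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            checksum = 0
            # Touch one byte per page so the kernel maps every page in.
            for offset in range(0, SIZE, mmap.PAGESIZE):
                checksum += mm[offset]
            after, _ = tracemalloc.get_traced_memory()

    tracemalloc.stop()
finally:
    os.remove(path)

traced_growth_mb = (after - before) / 2**20
print(f"pages touched: {SIZE // mmap.PAGESIZE}, "
      f"tracemalloc growth: {traced_growth_mb:.4f}MB")
```

Every page of the file has been read, yet the traced growth stays far below the file size, which matches the flat 3rd and 4th columns above.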

Someone suggested in one of the PA issues that IPC/GRPC could be the culprit. Any suggestions on how to debug this one?

The main issue is that one can't step through with a python debugger as arr is an opaque cpp object bound to python.

Please see the next comment for a possible answer.

ref-count

I also traced reference counts and they are all fixed using either sys.getrefcount(x) or len(gc.get_referrers(x))

so it's not the python object

Important related discussions

https://issues.apache.org/jira/browse/ARROW-11007 - looks very similar to our issue in particular this part of the report: https://issues.apache.org/jira/browse/ARROW-11007?focusedCommentId=17279642&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-17279642

stas00 commented 2 years ago

There is no leak, just badly communicated linux RSS memory usage stats

Next, let's revisit @rwightman's suggestion that there is actually no leak.

After all - we are using mmap which will try to map the file to RAM as much as it can and then page out if there is no memory. i.e. MMAP is only fast if you have a lot of CPU RAM.

So let's do it:

Memory mapping OOM test

We first quickly start a cgroups-controlled shell which will instantly kill any program that consumes more than 1GB of memory:

$ systemd-run --user --scope -p MemoryHigh=1G -p MemoryMax=1G -p MemorySwapMax=1G --setenv="MEMLIMIT=1GB" bash

Let's check that it indeed does so. Let's change @lhoestq's script to allocate a 10GB arrow file:

$ python -c 'import pyarrow as pa; pa.array([b"a" * (2000 * 1024)] * 5000)'
Killed

oops, that didn't work, as we tried to allocate 10GB when only 1GB is allowed. This is what we want!

Let's do a sanity check - can we allocate 0.1GB?

python -c 'import pyarrow as pa; pa.array([b"a" * (2000 * 1024)] * 50)'

Yes. So the limited shell does the right thing. It lets us allocate < 1GB of RSS RAM.

Next let's go back to @lhoestq's script but with 10GB arrow file.

we change his repro script https://github.com/huggingface/datasets/issues/4883#issuecomment-1242034985 to 50x larger file

    arr = pa.array([b"a" * (2000 * 1024)] * 5000)  # ~10000MB

We first have to run it in a normal unlimited shell so that we don't get killed (as the script allocates 10GB).

let's run the script now in the 1GB-limited shell while running a monitor:

$ htop -F python -s M_RESIDENT -u `whoami`

so we have 2 sources of RSS info just in case.

$ python pyar.py
   0       4.3516MB       0.0103       0.0194 9.766e+03MB  0.0MB  0.0MB
  10      24.3008MB       0.0104       0.0195 9.766e+03MB  0.0MB  0.0MB
[...]
4980    9730.3672MB       0.0108       0.0199 9.766e+03MB  0.0MB  0.0MB
4990    9750.9922MB       0.0108       0.0199 9.766e+03MB  0.0MB  0.0MB
PyArrow mem pool info: jemalloc backend, 0 allocated, 0 max allocated,
PyArrow total allocated bytes: 0

But wait, it reported 10GB RSS both in htop and in our log!

So that means it never allocated 10GB, otherwise it'd have been killed.

Which tells us that there is no leak whatsoever and this is just a really difficult situation where MMAPPED memory is reported as part of RSS, which it probably shouldn't be. Now we have no way to measure real memory usage.
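One possible way to separate the two kinds of memory, assuming Linux 4.5 or newer: /proc/self/status splits RSS into RssAnon (anonymous memory such as the heap), RssFile (resident file mappings, which is where mmap-ed file pages are counted) and RssShmem. The sketch below just parses those fields; the helper name `rss_breakdown` is hypothetical, and it returns an empty dict where the fields don't exist.

```python
def rss_breakdown(status_path="/proc/self/status"):
    """Return {'RssAnon': kB, 'RssFile': kB, 'RssShmem': kB} where available.

    Returns an empty dict when the file or the fields are missing
    (non-Linux systems or kernels older than 4.5).
    """
    wanted = ("RssAnon", "RssFile", "RssShmem")
    result = {}
    try:
        with open(status_path) as fh:
            for line in fh:
                key, _, rest = line.partition(":")
                if key in wanted:
                    result[key] = int(rest.split()[0])  # value is in kB
    except OSError:
        pass
    return result

print(rss_breakdown())
```

If RssFile grows while RssAnon stays flat during iteration, the growth is most likely coming from mapped file pages rather than from allocations.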

I also attached the script with all the different things I have tried in it, so it should be easy to turn them on/off if you want to reproduce any of my findings.

pyar.txt

just rename it to pyar.py as GH doesn't allow attaching scripts...

(I have to remember to exit that special mem-limited shell or else I won't be able to do anything serious there.)

stas00 commented 2 years ago

The original leak in the multi-modal code is very likely something else. But of course now it'd be very difficult to trace it using mmap.

I think to debug we have to set keep_in_memory=True in load_from_disk to load the small dataset in RAM, so there will be no mmap misleading reporting component and then continue searching for another source of a leak.

NouamaneTazi commented 2 years ago

To add to what @stas00 found, I'm gonna leave some links to where I believe the confusion came from in pyarrow's APIs, for future reference:

Arrow can directly reference the data mapped from disk and avoid having to allocate its own memory.

And where their example shows 0 RSS memory allocation, the way we used to measure RSS shows 39.6719MB allocated. Here's the script to reproduce:

import psutil
import os
import gc
import pyarrow as pa
import time
import sys

# gc.set_debug(gc.DEBUG_LEAK)
# gc.set_threshold(0,0,0)

#pa.set_memory_pool(pa.mimalloc_memory_pool())
#pa.set_memory_pool(pa.system_memory_pool())

import tracemalloc

#pa.jemalloc_set_decay_ms(0)
# pa.log_memory_allocations(enable=True)

BATCH_SIZE = 10000
NUM_BATCHES = 1000
schema = pa.schema([pa.field('nums', pa.int32())])
with pa.OSFile('bigfile.arrow', 'wb') as sink:
   with pa.ipc.new_file(sink, schema) as writer:
      for row in range(NUM_BATCHES):
            batch = pa.record_batch([pa.array(range(BATCH_SIZE), type=pa.int32())], schema)
            writer.write(batch)

start_use = pa.total_allocated_bytes()
pool = pa.default_memory_pool()
start_peak_use = pool.max_memory()
tracemalloc.start()
first_size, first_peak = tracemalloc.get_traced_memory()
mem_before = psutil.Process(os.getpid()).memory_info().rss / 2**20

# with pa.OSFile('bigfile.arrow', 'rb') as source:
#    loaded_array = pa.ipc.open_file(source).read_all()

with pa.memory_map('bigfile.arrow', 'rb') as source:
   loaded_array = pa.ipc.open_file(source).read_all()

print("LEN:", len(loaded_array))
print("RSS: {}MB".format(pa.total_allocated_bytes() >> 20))

gc.collect()
time.sleep(0.1)
mem_after = psutil.Process(os.getpid()).memory_info().rss / 2**20
mem_use = pa.total_allocated_bytes() - start_use
mem_peak = pool.max_memory() - start_peak_use
second_size, second_peak = tracemalloc.get_traced_memory()
mem_diff = (second_size - first_size) / 2**20
mem_peak_diff = (second_peak - first_peak) / 2**20

idx = 0
print(f"{idx:4d} {mem_after - mem_before:12.4f}MB {mem_diff:12.4f} {mem_peak_diff:12.4f} {mem_use/2**20:4.4}MB {mem_peak/2**20:4.4}MB")

gives:


LEN: 10000000
RSS: 0MB
   0      39.6719MB       0.0132       0.0529  0.0MB  0.0MB

Which again just proves that we incorrectly measure RSS in the case of MMAPPED memory.

stas00 commented 1 year ago

@lhoestq, I have been working on a detailed article that shows that MMAP doesn't leak and it's mostly ready. I will share when it's ready.

The issue is that we still need to be able to debug memory leaks by turning MMAP off.

But, once I tried to show the user that using load_dataset(... keep_in_memory=True) is the way to debug an actual memory leak - guess what I discovered? A potential actual leak.

Here is the repro:

$ cat ds-mmap.py
from datasets import load_dataset
import gc
import os
import psutil

proc = psutil.Process(os.getpid())
def mem_read():
    gc.collect()
    return proc.memory_info().rss / 2**20

dataset = load_dataset("wmt19", 'cs-en', keep_in_memory=True, streaming=False)['train']

print(f"{'idx':>6} {'RSS':>10} {'Δ RSS':>15}")
step = 20000
for i in range(0, 10*step, step):
    mem_before = mem_read()
    _ = dataset[i:i+step]
    mem_after = mem_read()
    print(f"{i:6d} {mem_after:12.4f}MB {mem_after - mem_before:12.4f}MB ")

$ python ds-io.py
Reusing dataset wmt19 (/home/stas/.cache/huggingface/datasets/wmt19/cs-en/1.0.0/c3db1bf4240362ed1ef4673b354f468d70aac66d4e67d45f536d493a0840f0d3)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.66it/s]
   idx        RSS           Δ RSS
     0    1398.4609MB       3.5195MB
 20000    1398.5742MB       0.1133MB
 40000    1398.6016MB       0.0273MB
 60000    1398.6016MB       0.0000MB
 80000    1398.6016MB       0.0000MB
100000    1398.6328MB       0.0312MB
120000    1398.6953MB       0.0625MB
140000    1398.6953MB       0.0000MB
160000    1398.7500MB       0.0547MB
180000    1398.7500MB       0.0000MB

stas00 commented 1 year ago

As I suggested on slack, perhaps it was due to variation in the length of the dataset records, so with your help I wrote another repro with synthetic records which are all identical. This should remove my hypothesis from the equation, and we should expect 0 incremental growth as we iterate over the dataset. But alas this is not the case. There is a tiny but definite leak-like behavior.

Here is the new repro:

$ cat ds-synthetic-no-mmap.py
from datasets import load_from_disk, Dataset
import gc
import sys
import os
import psutil

proc = psutil.Process(os.getpid())
def mem_read():
    gc.collect()
    return proc.memory_info().rss / 2**20

DS_PATH = "synthetic-ds"
if not os.path.exists(DS_PATH):
    records = 1_000_000
    print("Creating a synthetic dataset")
    row = dict(foo=[dict(a='a'*500, b='b'*1000)])
    ds = Dataset.from_dict({k: [v] * records for k, v in row.items()})
    ds.save_to_disk(DS_PATH)
    print("Done. Please restart the program")
    sys.exit()

dataset = load_from_disk(DS_PATH, keep_in_memory=True)
print(f"Dataset len={len(dataset)}")

print(f"{'idx':>8} {'RSS':>10} {'Δ RSS':>15}")
mem_start = 0
step = 25_000
warmup_iterations = 4
for idx, i in enumerate(range(0, len(dataset), step)):
    if idx == warmup_iterations: # skip the first few iterations while things get set up
        mem_start = mem_read()
    mem_before = mem_read()
    _ = dataset[i:i+step]
    mem_after = mem_read()
    print(f"{i:8d} {mem_after:12.4f}MB {mem_after - mem_before:12.4f}MB")
mem_end = mem_read()

print(f"Total diff: {mem_end - mem_start:12.4f}MB (after {warmup_iterations} warmup iterations)")

and the run:

$ python ds-synthetic-no-mmap.py
Dataset len=1000000
     idx        RSS           Δ RSS
       0    1601.9258MB      47.9688MB
   25000    1641.6289MB      39.7031MB
   50000    1641.8594MB       0.2305MB
   75000    1642.1289MB       0.2695MB
  100000    1642.1289MB       0.0000MB
  125000    1642.3789MB       0.2500MB
  150000    1642.3789MB       0.0000MB
  175000    1642.6289MB       0.2500MB
  200000    1642.6289MB       0.0000MB
  225000    1642.8789MB       0.2500MB
  250000    1642.8828MB       0.0039MB
  275000    1643.1328MB       0.2500MB
  300000    1643.1328MB       0.0000MB
  325000    1643.3828MB       0.2500MB
  350000    1643.3828MB       0.0000MB
  375000    1643.6328MB       0.2500MB
  400000    1643.6328MB       0.0000MB
  425000    1643.8828MB       0.2500MB
  450000    1643.8828MB       0.0000MB
  475000    1644.1328MB       0.2500MB
  500000    1644.1328MB       0.0000MB
  525000    1644.3828MB       0.2500MB
  550000    1644.3828MB       0.0000MB
  575000    1644.6328MB       0.2500MB
  600000    1644.6328MB       0.0000MB
  625000    1644.8828MB       0.2500MB
  650000    1644.8828MB       0.0000MB
  675000    1645.1328MB       0.2500MB
  700000    1645.1328MB       0.0000MB
  725000    1645.3828MB       0.2500MB
  750000    1645.3828MB       0.0000MB
  775000    1645.6328MB       0.2500MB
  800000    1645.6328MB       0.0000MB
  825000    1645.8828MB       0.2500MB
  850000    1645.8828MB       0.0000MB
  875000    1646.1328MB       0.2500MB
  900000    1646.1328MB       0.0000MB
  925000    1646.3828MB       0.2500MB
  950000    1646.3828MB       0.0000MB
  975000    1646.6328MB       0.2500MB
Total diff:       4.5039MB (after 4 warmup iterations)

so I'm still not sure why we get this.

As you can see I started skipping the first few iterations where memory isn't stable yet, as the actual diff is much larger if we count all iterations.

What do you think?

rwightman commented 1 year ago

@stas00 my 2 cents from having looked at a LOT of memory leaks over the years, esp in Python: with a .3% memory increase over that many iterations of something, it is difficult to say with certainty that it is a leak.
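To put that figure in context, here is the arithmetic on the numbers from the synthetic run above. It is only a worked calculation; the two RSS values are copied from the log (the reading after the 4 warmup iterations and the final reading).

```python
MB = 2**20

rss_at_warmup_end = 1642.1289   # MB, RSS after the 4 warmup iterations
rss_at_end = 1646.6328          # MB, RSS after the last iteration
step = 25_000                   # rows read per iteration
rows_total = 1_000_000
warmup_iterations = 4

growth_mb = rss_at_end - rss_at_warmup_end
measured_iterations = rows_total // step - warmup_iterations
relative_growth = growth_mb / rss_at_warmup_end
bytes_per_row = growth_mb * MB / (measured_iterations * step)

print(f"growth: {growth_mb:.4f}MB over {measured_iterations} iterations")
print(f"relative growth: {relative_growth:.2%}")
print(f"per row: {bytes_per_row:.1f} bytes")
```

The growth works out to a few bytes per row read, which is small enough that allocator bookkeeping or fragmentation are plausible explanations alongside a real leak.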

Also, just looking at RSS makes it hard to analyze leaks. RSS can stay near constant while you are leaking. RSS is paged in mem, if you have a big leak your RSS might not increase much (leaked mem tends not to get used again so often paged out) while your virtual page allocation could be going through the roof...

stas00 commented 1 year ago

yes, that's true, but unless the leak is big, I'm yet to find another measurement tool.

To prove your point here is a very simple IO in a loop program that also reads the same line all over again:

$ cat mmap-no-leak-debug.py
import gc
import mmap
import os
import psutil
import sys

proc = psutil.Process(os.getpid())

PATH = "./tmp.txt"

def mem_read():
    gc.collect()
    return proc.memory_info().rss / 2**20

# create a large data file with a few long lines
if not os.path.exists(PATH):
    with open(PATH, "w") as fh:
        s = 'a'* 2**27 + "\n" # 128MB
        # write ~2GB file
        for i in range(16):
            fh.write(s)

print(f"{'idx':>4} {'RSS':>10}   {'Δ RSS':>12}   {'Δ accumulated':>10}")

total_read = 0
content = ''
mem_after = mem_before_acc = mem_after_acc = mem_before = proc.memory_info().rss / 2**20
print(f"{0:4d} {mem_after:10.2f}MB {mem_after - 0:10.2f}MB {0:10.2f}MB")

mmap_mode = True if "--mmap" in sys.argv else False

with open(PATH, "r") as fh:

    if mmap_mode:
        mm = mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ)

    idx = 0
    while True:
        idx += 1
        mem_before = mem_read()
        line = mm.readline() if mmap_mode else fh.readline()
        if not line:
            break

        #total_read += len(line)

        if "--accumulate" in sys.argv:
            mem_before_acc = mem_read()
            content += str(line)
            mem_after_acc = mem_read()

        mem_after = mem_read()

        print(f"{idx:4d} {mem_after:10.2f}MB {mem_after - mem_before:10.2f}MB {mem_after_acc - mem_before_acc:10.2f}MB")

it has some other instrumentations to do mmap and accumulate data, but let's ignore that for now.

Here it is running in a simple non-mmap IO:

$ python mmap-no-leak-debug.py
 idx        RSS          Δ RSS   Δ accumulated
   0      12.43MB      12.43MB       0.00MB
   1     269.72MB     257.29MB       0.00MB
   2     269.73MB       0.02MB       0.00MB
   3     269.73MB       0.00MB       0.00MB
   4     269.74MB       0.01MB       0.00MB
   5     269.74MB       0.00MB       0.00MB
   6     269.75MB       0.01MB       0.00MB
   7     269.75MB       0.00MB       0.00MB
   8     269.76MB       0.01MB       0.00MB
   9     269.76MB       0.00MB       0.00MB
  10     269.77MB       0.01MB       0.00MB
  11     269.77MB       0.00MB       0.00MB
  12     269.77MB       0.00MB       0.00MB
  13     269.77MB       0.00MB       0.00MB
  14     269.77MB       0.00MB       0.00MB
  15     269.77MB       0.00MB       0.00MB
  16     146.02MB    -123.75MB       0.00MB

as you can see even this super-simplistic program that just performs readline() slightly increases in RSS over iterations.

If you have a better tool for measurement other than RSS, I'm all ears.

rwightman commented 1 year ago

@stas00 if you aren't using memory maps, you should be able to clearly see the increase in the virtual mem for the process as well. Even then, it could still be challenging to determine if it's leak vs fragmentation due to problematic allocation patterns (not uncommon with Python). Using a better mem allocator like tcmalloc via LD_PRELOAD hooks could reduce impact of fragmentation across both Python and c libs. Not sure that plays nice with any allocator that arrow might use itself though.

stas00 commented 1 year ago

Thank you for these suggestions, Ross.

The problem is that most of the time we use a bunch of non-python libs that are bound to python and so besides python, one has to deal with not-quite controllable allocation strategies by those other components as well. So it's a super-tricky world.

Good suggestion on memory fragmentation, which could definitely be one of the sources for ever-growing RSS. pytorch's memory management utils are mostly quite excellent, and fragmentation is one of the main issues there. Projects like Deepspeed try to solve it by pre-allocating memory themselves and then managing it tightly to avoid fragmentation, which seems to work quite well.

BTW, I'm not sure if you have seen this tool I developed some years back to automatically track and report CPU and GPU memory usage in Jupyter notebooks. https://github.com/stas00/ipyexperiments I found it to be quite useful for detecting memory leakage - of course it's the same RSS for CPU, but it's just automated where each cell reports the delta. One other tricky thing to measure is CPU peak memory which it provides. As often there are those temp leaks which lead to OOMs.

stas00 commented 1 year ago

OK, I ended up compiling the different stages of the research into an article, including a recommendation on how to keep datasets from interfering with memory leak debugging: https://gist.github.com/stas00/035d823106a896b1ba2032d34b97d541

@lhoestq, please have a look at your convenience and let's discuss how we use that in the datasets docs. We can for example:

  1. include it as a separate doc in the datasets docs.

  2. publish it as an HF blog post, link to it from the datasets docs and perhaps only quote the last section that shows how to debug memory leaks while using datasets

of course I'm open to other options.

(I of course need to proof-read it, it surely could use an editing pass, I only checked the numbers made sense, but it should be quite readable already)

stas00 commented 1 year ago

And I will paste the last section of the article here for posterity should the original disappear:

Using synthetic MMAP-disabled dataset to debug memory leaks

Therefore the easiest approach is to create a synthetic dataset of desired length with all records being the same. That way the data is no longer a factor in the memory usage patterns as it's always the same.

$ cat ds-synthetic-no-mmap.py
from datasets import load_from_disk, Dataset
import gc
import sys
import os
import psutil

proc = psutil.Process(os.getpid())
def mem_read():
    gc.collect()
    return proc.memory_info().rss / 2**20

DS_PATH = "synthetic-ds"
if not os.path.exists(DS_PATH):
    records = 1_000_000
    print("Creating a synthetic dataset")
    row = dict(foo=[dict(a='a'*500, b='b'*1000)])
    ds = Dataset.from_dict({k: [v] * records for k, v in row.items()})
    ds.save_to_disk(DS_PATH)
    print("Done. Please restart the program")
    sys.exit()

dataset = load_from_disk(DS_PATH, keep_in_memory=True)
print(f"Dataset len={len(dataset)}")

print(f"{'idx':>8} {'RSS':>10} {'Δ RSS':>15}")
mem_start = 0
step = 50_000
warmup_iterations = 4
for idx, i in enumerate(range(0, len(dataset), step)):
    if idx == warmup_iterations: # skip the first few iterations while things get set up
        mem_start = mem_read()
    mem_before = mem_read()
    _ = dataset[i:i+step]
    mem_after = mem_read()
    print(f"{i:8d} {mem_after:12.4f}MB {mem_after - mem_before:12.4f}MB")
mem_end = mem_read()

print(f"Total diff: {mem_end - mem_start:12.4f}MB (after {warmup_iterations} warmup iterations)")

We run this program once to create the dataset, and then the second time to profile its memory usage:

$ python ds-synthetic-no-mmap.py
Creating a synthetic dataset
Done. Please restart the program
$ python ds-synthetic-no-mmap.py
Dataset len=1000000
     idx        RSS           Δ RSS
       0    1649.6055MB      95.1992MB
   50000    1728.4961MB      78.8906MB
  100000    1728.7109MB       0.2148MB
  150000    1729.2539MB       0.5430MB
  200000    1729.0039MB      -0.2500MB
  250000    1729.5039MB       0.5000MB
  300000    1729.2539MB      -0.2500MB
  350000    1729.7539MB       0.5000MB
  400000    1729.5039MB      -0.2500MB
  450000    1730.0039MB       0.5000MB
  500000    1729.7539MB      -0.2500MB
  550000    1730.2539MB       0.5000MB
  600000    1730.0039MB      -0.2500MB
  650000    1730.5039MB       0.5000MB
  700000    1730.2539MB      -0.2500MB
  750000    1730.7539MB       0.5000MB
  800000    1730.5039MB      -0.2500MB
  850000    1731.0039MB       0.5000MB
  900000    1730.7539MB      -0.2500MB
  950000    1731.2539MB       0.5000MB
Total diff:       2.0000MB (after 4 warmup iterations)

This is much better. There are still tiny fluctuations due to Python, and you can see in the code that I skipped the first few iterations while things are being set up.

But otherwise now you can easily debug the rest of your code for any memory leaks since datasets are in non-MMAP mode and the records size doesn't fluctuate.

Of course, do not forget to flip load_from_disk(..., keep_in_memory=True) to False when the debug is over so that you get back the performance speed up provided by MMAP.

lhoestq commented 1 year ago

Thanks ! Before thinking about the documentation, I'd like to make sure we can explain what happens in a correct and simple manner.

When accessing memory mapped data, the corresponding pages are loaded one at a time in the main memory and increase RSS.

When the pages are not used anymore, they are paged out otherwise it would fill up the physical RAM. However RSS doesn't decrease. If you read the entire file, then the RSS will end up bigger than the file, no matter how much physical RAM you have.

Is this correct ?

stas00 commented 1 year ago

All but the last sentence I think.

It won't be able to read into memory more data than there is actual free host memory, as those are real paged-in data. If there is no free memory, it can't read it in.

The cgroups-limited shell just shows that the paged-in, cached MMAPed data doesn't count towards the process's actually used memory.

In other words if your host has 8GB of RAM and 0 SWAP, and all of that RAM is free and your process is unlimited it'll only be able to page in 8 GB of MMAPed data.

rwightman commented 1 year ago

It's challenging to describe what happens in a simple manner because it's not so simple. It's system level behaviour that depends on the system wide configuration and all of the processes running. The OS virtual memory manager needs to manage the pages across all of the processes, incl the mmap'd files. Pages can stay in a while if there is little demand, or get paged out quickly if there is high demand (from possibly any of the other processes on the system). Random access patterns (i.e. needing one entry here, one there) will require a new page-in for each access, or multiple pages if sequential read-ahead is enabled. If you have little system memory and are mapping large files, jumping around, you can easily end up thrashing (constantly paging in, and then ejecting, which results in disk IO for mmap'd files). So what you observe for a single process, what's resident, can change based on what other processes are doing as well...

Many databases and applications that do heavy IO on large files opt to handle IO and buffer management themselves vs leave it up to the OS VMM. While mmap'd files are usually good enough, there are some pretty confusing edge case behaviours. Just loading the buffers directly into memory will be more consistent behaviour wise. However, that is only practical if the maximum file size is limited to a reasonable value (ie data is sharded into 256MB-2GB chunks).

stas00 commented 1 year ago

Indeed, this is all very complex and there is no deterministic way of tracking real memory usage of a process as a whole, especially so when different unrelated components are intermixed - e.g. python and 3rd party c++ libraries bound to it. Each component's memory allocations can be traced separately, e.g. tracemalloc for python, torch.cuda.memory* for torch, mem_pool.bytes_allocated for pyarrow, etc.
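
As a rough illustration of the per-component approach, here is a sketch using only the standard library's tracemalloc. The deliberately retained list stands in for a leak and is invented for the example. tracemalloc only sees allocations made through Python's own allocators, so memory held by C++ extensions will not show up here.

```python
import tracemalloc

tracemalloc.start()
before = tracemalloc.take_snapshot()

# Deliberately keep ~10MB of Python objects alive to simulate a leak.
leaked = [bytes(1024) for _ in range(10_000)]

after = tracemalloc.take_snapshot()
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

# Show the source lines responsible for the largest growth between snapshots.
for stat in after.compare_to(before, "lineno")[:3]:
    print(stat)
print(f"traced now: {current / 2**20:.2f}MB, peak: {peak / 2**20:.2f}MB")
```

Comparing two snapshots points at the line that allocated the retained objects, which RSS alone cannot do.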

Instead of trying to explain all these complexities, I propose that all the lay user needs to know is this:

  1. If the host has free memory, the mmap mechanism will use that free memory for caching and will immediately release any or all of that cached memory if another process needs it. It will only hold onto the memory that is actually being used by the process.
  2. Because of that the reported shared memory can be alarmingly large and appear as a leaked memory. But this is not the case.
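
The two points above can be observed with a small experiment: memory-map a scratch file, touch every page, and compare the resident and shared counters before and after. This is a sketch under assumptions: it reads /proc/self/statm, so the counters are only available on Linux, and both helper names are hypothetical.

```python
import mmap
import tempfile

PAGE = mmap.PAGESIZE


def rss_and_shared_mb():
    """Return (RSS, resident file-backed/shared pages) in MiB, or None without /proc."""
    try:
        with open("/proc/self/statm") as fh:
            _size, resident, shared = fh.read().split()[:3]
    except OSError:
        return None
    return int(resident) * PAGE / 2**20, int(shared) * PAGE / 2**20


def touch_mmapped_file(size_mb=64):
    """mmap a scratch file and read one byte from every page."""
    size = size_mb * 2**20
    with tempfile.NamedTemporaryFile() as tmp:
        tmp.write(b"a" * size)
        tmp.flush()
        before = rss_and_shared_mb()
        with mmap.mmap(tmp.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            # Indexing a single byte is enough to page that page in.
            pages = sum(1 for off in range(0, size, PAGE) if mm[off])
            after = rss_and_shared_mb()
    return pages, before, after


pages, before, after = touch_mmapped_file()
print(f"touched {pages} pages")
print(f"(RSS, shared) before: {before}")
print(f"(RSS, shared) after:  {after}")
```

If RSS and the shared counter grow together by roughly the file size, the growth is page cache attributed to the process rather than memory the process allocated.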

In my write up I have crafted a few set ups using code samples and cgroups-limited shells that demonstrate the above. So the above summary can be TLDR and the article can be the long story for those who care for an empirical demonstration.

If you have about 10min please kindly skim through my write up and let me know if you found some holes / unclear parts or whether this clears things up after reading it. Thank you!

stas00 commented 1 year ago

I asked around and it seemed that my write up won't fit into the HF blog, so I published it on my blog https://stasosphere.com/entrepreneur-being/301-mmap-memory-leak-investigation/ I added context and edited it some more.

@lhoestq, I am not sure if you still want to include any of my notes in the datasets docs or not. It's totally your call.

lhoestq commented 1 year ago

Ok :) I think the explanations in https://github.com/huggingface/datasets/issues/4883#issuecomment-1252722977 should be helpful for many users already. And maybe redirect to your blog for the details ? Then to debug how much Arrow data is physically in RAM we can also advise pa.total_allocated_bytes. What do you think ?
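
A short sketch of that check, assuming pyarrow is importable. pa.total_allocated_bytes() reports what Arrow's default memory pool has allocated, so other pools may not be counted. The wrapper name arrow_allocated_mb and the sample table are made up for illustration.

```python
def arrow_allocated_mb():
    """Return MiB allocated by Arrow's default memory pool, or None without pyarrow."""
    try:
        import pyarrow as pa
    except ImportError:
        return None
    return pa.total_allocated_bytes() / 2**20


before = arrow_allocated_mb()
if before is None:
    print("pyarrow is not installed")
else:
    import pyarrow as pa

    # A table built from Python objects lives in RAM, not in a memory map.
    table = pa.table({"x": list(range(1_000_000))})
    after = arrow_allocated_mb()
    print(f"Arrow allocated: {before:.2f}MB -> {after:.2f}MB")
```

Calling the same helper around a load_from_disk(...) with and without keep_in_memory=True would be one way to compare Arrow-owned memory against the RSS figures discussed in this thread.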

stas00 commented 1 year ago

That sounds like a good plan, @lhoestq!

plamb-viso commented 11 months ago

Wanted to add one method for dealing with issues in datasets/pyarrow that seem related to this thread (and i hope helps someone; I imagine they'll land here like i did). I was encoding image data on a fairly large dataset using datasets .map() in a vanilla way and running into mysterious pyarrow errors or no errors at all and simply hanging forever. When run locally on one processor, I was getting the same results as this thread (memory seemed to increase forever). I tried every solution I could find with no luck.

I came up with a simple way to reduce the pressure on datasets/pyarrow during mapping/index flattening while still getting some of the benefit of parallel processing; this method was able to encode and save a 229gb dataset on a system with much less memory than that.

In my case the dataset starts as rows with file paths; during the .map() phase, for each row (or training instance) the called function can open these file paths to perform encoding. Because of this, I can know the length of the dataset in advance.

This method is much slower than vanilla mapping, but enabled me to get past all the weird pyarrow issues and successfully encode a dataset that was much larger than the available memory on the instance it was running on.

rokayabencheikh commented 9 months ago

Wanted to add one method for dealing with issues in datasets/pyarrow that seem related to this thread (and i hope helps someone; I imagine they'll land here like i did). I was encoding image data on a fairly large dataset using datasets .map() in a vanilla way and running into mysterious pyarrow errors or no errors at all and simply hanging forever. When run locally on one processor, I was getting the same results as this thread (memory seemed to increase forever). I tried every solution I could find with no luck.

I came up with a simple way to reduce the pressure on datasets/pyarrow during mapping/index flattening while still getting some of the benefit of parallel processing; this method was able to encode and save a 229gb dataset on a system with much less memory than that.

In my case the dataset starts as rows with file paths; during the .map() phase, for each row (or training instance) the called function can open these file paths to perform encoding. Because of this, I can know the length of the dataset in advance.

  • Using the dataset length and passed mapper batch size, find a divisor of the dataset length that divides it evenly (in my case, if it happens to be prime, just subtract one training example off the end)
  • Using this divisor, iterate the dataset in chunks of this divisor size (call it a 'step' size).
  • The actual .map() batch_size/writer_batch_size can then be the step size divided by the number of processors with some sane minimum. For e.g. if you are dropping bad instances using batched=True you need a sane minimum so a given processor doesn't result in 0 instances. So basically you are mapping on a chunk of the dataset and dividing that chunk among the processors.
  • Start iterating on the dataset given the step size; pass that chunk to .map() with a batch_size/writer_batch_size that divides it among the processors
  • When a given chunk completes the .map() phase, call save_to_disk on it so it doesn't stay in memory or cause any weirdness with memory mapping.
  • Once all chunks have been encoded, iterate the directory where they were saved and call load_from_disk() on them (with keep_in_memory=False of course) and append them to a list
  • Call concatenate_datasets() on that list.

This method is much slower than vanilla mapping, but enabled me to get past all the weird pyarrow issues and successfully encode a dataset that was much larger than the available memory on the instance it was running on.

plz give us the python code

plamb-viso commented 9 months ago

I don't have time to generalize or really explain each step here, but maybe this will help someone. I recently used this technique on a dataset with ~5,000,000 sentences when straightforward dataset.map() failed in a way I've become familiar with. Note that the chosen config.batch_size will have a huge impact on performance since each chunk is getting saved to disk. You want to pick one that's basically as large as possible knowing it's split among x processors, but not so large that it hits the memory weirdness with straightforward dataset.map()'ing. For e.g. on that 5 mil sentence dataset, a batch size of 50,000 worked very well.

        dataset_len = len(dataset)
        step = None
        for num in range(config.batch_size, dataset_len):
            if dataset_len % num == 0:
                step = num
                break
        if not step:
            dataset = dataset.select(range(0, dataset_len - 1))
            dataset_len = dataset_len - 1
            for num in range(config.batch_size, dataset_len):
                if dataset_len % num == 0:
                    step = num
                    break
        if not step:
            raise Exception(f"Could not find a step size for {dataset_len}")
        batch_size = int(step/num_processes) if int(step/num_processes) >= 200 else 200
        print(
            f'Initiating encoding via concatenating datasets. num_processes={num_processes}, keep_in_memory={config.keep_in_memory}, dataset_len={dataset_len}, step={step}, batch_size={batch_size}')
        for idx in range(0, dataset_len, step):
            print(f"Operating on {idx} -> {idx+step} of {dataset_len}")
            chunk = dataset.select(range(idx, idx+step))
            encoded_chunk = chunk.map(
                preprocess_data,
                fn_kwargs=dict(tokenizer=tokenizer, labels=labels, label2id=label2id, config=config,
                               batch_size=batch_size),
                keep_in_memory=config.keep_in_memory,
                batched=True,
                features=features,
                remove_columns=dataset.column_names,
                num_proc=num_processes,
                batch_size=batch_size,
                writer_batch_size=batch_size
            )
            print(f"Saving {idx} -> {idx+step} chunk to disk")
            encoded_chunk.save_to_disk(f"{SAGEMAKER_CHUNK_DIR}/chunk_{idx}_{idx+step}")

        print("Loading encoded chunks for concatenating")
        encoded_chunks = []
        parent_path, subdirs, parent_files = next(os.walk(SAGEMAKER_CHUNK_DIR))
        for chunk in subdirs:
            print(f"Loading chunk: {chunk}")
            curr_path = os.path.join(parent_path, chunk)
            encoded_chunks.append(load_from_disk(curr_path))
        print(f"Concatenating {len(encoded_chunks)} datasets")
        encoded_dataset = concatenate_datasets(encoded_chunks)
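
The step-size search at the top of the snippet above can be isolated into a small pure-Python helper, which makes it easy to check without a dataset. This is a sketch only: find_step is a hypothetical name, and min_step plays the role of config.batch_size.

```python
def find_step(dataset_len, min_step):
    """Return (step, usable_len) where step >= min_step evenly divides usable_len.

    Mirrors the search above: try the full length first, then drop one
    trailing example and try again.
    """
    for usable_len in (dataset_len, dataset_len - 1):
        for num in range(min_step, usable_len):
            if usable_len % num == 0:
                return num, usable_len
    raise ValueError(f"Could not find a step size for {dataset_len}")


print(find_step(5_000_000, 50_000))
print(find_step(13, 3))
```

For a prime length such as 13, the helper falls back to 12 usable rows, which matches dropping one training example off the end as described in the steps.
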
naarkhoo commented 7 months ago

I ended up splitting my DatasetDicts into N pieces

split_data = False
split_datadict = True

N = 50
total_rows = len(gtzan)
split_size = total_rows // N
remaining = total_rows % N  # To handle any remainder

from datasets import DatasetDict

if split_datadict:
    # Initialize a list to hold the new DatasetDicts
    split_dataset_dicts = [DatasetDict() for _ in range(N)]

    # Process each dataset in the DatasetDict
    for dataset_name, dataset in gtzan.items():
        total_rows = len(dataset)
        split_size = total_rows // N
        remaining = total_rows % N

        # Create the splits for this dataset
        start_idx = 0
        for i in range(N):
            end_idx = start_idx + split_size + (1 if i < remaining else 0)
            split = dataset.select(range(start_idx, end_idx))

            # Add this split to the corresponding new DatasetDict
            split_dataset_dicts[i][dataset_name] = split

            start_idx = end_idx
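
The index arithmetic used for the splits can be checked on its own with plain integers. A minimal sketch, with split_bounds as a hypothetical helper name:

```python
def split_bounds(total_rows, n):
    """Return n (start, end) index pairs covering range(total_rows) without gaps."""
    split_size, remaining = divmod(total_rows, n)
    bounds = []
    start = 0
    for i in range(n):
        # The first `remaining` splits each take one extra row.
        end = start + split_size + (1 if i < remaining else 0)
        bounds.append((start, end))
        start = end
    return bounds


print(split_bounds(10, 3))  # [(0, 4), (4, 7), (7, 10)]
```

Note that when a dataset has fewer rows than N, the trailing splits come out empty, so a small split of the DatasetDict may need a smaller N.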

then run my function

# Initialize an empty list for the processed splits
processed_splits = []

# Loop through each split and apply the preprocessing
for cnt, split in enumerate(split_dataset_dicts):
    print(f"iteration:{cnt}")
    processed_split = split.map(preprocess_function,
                                remove_columns=["audio"],
                                batched = True,
                                num_proc=4,
                                )

    processed_splits.append(processed_split)

Interestingly enough, a list comprehension does not work here and hangs too! So go with a good old-fashioned loop.

then concat the result

from datasets import concatenate_datasets, DatasetDict

# processed_splits is the list of processed DatasetDicts built by the loop above

# Initialize a dictionary to hold the concatenated datasets
concatenated_datasets = {}

# Get all dataset names (keys) from the first split (assuming all splits have the same structure)
dataset_names = processed_splits[0].keys()

# Concatenate the datasets for each name
for name in dataset_names:
    # Collect the same part of each processed split
    datasets_to_concatenate = [split[name] for split in processed_splits]

    # Concatenate them
    concatenated_dataset = concatenate_datasets(datasets_to_concatenate)

    # Add to the dictionary
    concatenated_datasets[name] = concatenated_dataset

# Create a new DatasetDict with the concatenated datasets
concatenated_dataset_dict = DatasetDict(concatenated_datasets)