mittagessen / kraken

OCR engine for all the languages
http://kraken.re
Apache License 2.0

Training on large binary dataset #509

Closed colibrisson closed 1 year ago

colibrisson commented 1 year ago

I am trying to train a recognition model on a very large dataset. To speed up training, I have compiled my dataset into multiple Arrow files of 10 000 lines each.

I initially assumed that ArrowIPCRecognitionDataset just kept a reference to the Arrow files in memory. However, when I looked at ArrowIPCRecognitionDataset's add method, I realized that the whole dataset is loaded into memory before training starts:

https://github.com/mittagessen/kraken/blob/a0c395727c011d3283b34b5f7a9ef6d85970e6d0/kraken/lib/dataset/recognition.py#LL191C1-L195C37

Wouldn't it be better to load the arrow files one by one during training?

mittagessen commented 1 year ago

Nope, they aren't loaded into memory if a zero-copy concatenation is possible, i.e. if the table can be memory mapped. The principal case where this isn't possible is when using built-in fixed splits, because filtering a table requires it to be loaded in its entirety.

In any case, we need to iterate over the whole dataset before training to establish the alphabet after text transformations, and that's where you might see some memory use, but that's transient.

colibrisson commented 1 year ago

> In any case, we need to iterate over the whole dataset before training to establish the alphabet after text transformations, and that's where you might see some memory use, but that's transient.

That's not what I'm talking about.

> Nope, they aren't loaded into memory if a zero-copy concatenation is possible, i.e. if the table can be memory mapped. The principal case where this isn't possible is when using built-in fixed splits, because filtering a table requires it to be loaded in its entirety.

That's weird. My dataset doesn't contain any splits, but the table seems to get loaded into memory:

(Pdb) print(sys.getsizeof(self.arrow_table) * 0.000001)
12197.935502

This corresponds to the process memory consumption displayed by top. I tried with multiple datasets and always got the same result.

Please reopen the issue.

mittagessen commented 1 year ago

You can't use either sys.getsizeof() or top to determine the actual memory use of an object/process that uses memory mapping because the former prints whatever __sizeof__ returns and the latter shows virtual memory. Concatenation of tables doesn't result in them being loaded completely into memory. Look at the size of the arrow memory pool to convince yourself:

import sys
import pyarrow as pa

file1 = 'foo.arrow'
file2 = 'bar.arrow'

# Memory-map each file: the table's buffers point into the mapping instead of
# being copied into Arrow's memory pool.
with pa.memory_map(file1, 'rb') as source:
    ds_table_1 = pa.ipc.open_file(source).read_all()
# Bytes allocated from Arrow's pool vs. what sys.getsizeof() reports.
print(f'{pa.total_allocated_bytes()} {sys.getsizeof(ds_table_1)}')
with pa.memory_map(file2, 'rb') as source:
    ds_table_2 = pa.ipc.open_file(source).read_all()
print(f'{pa.total_allocated_bytes()} {sys.getsizeof(ds_table_2)}')
# Zero-copy concatenation: the pool allocation should not grow.
arrow_table = pa.concat_tables([ds_table_1, ds_table_2])
print(f'{pa.total_allocated_bytes()} {sys.getsizeof(arrow_table)}')

colibrisson commented 1 year ago

Thanks a lot! I will check that out tonight.

Maybe I'm not using the best way to measure memory consumption for memory-mapped objects. However, the out-of-memory errors I'm getting from the host (not the GPU) are not virtual, so I think there's an issue somewhere.

mittagessen commented 1 year ago

The skip_empty_lines switch might be the culprit here: the filter probably causes everything to get loaded. If you're sure there aren't any lines consisting of a single whitespace character or similar in there, you can just disable it.

colibrisson commented 1 year ago

773cc00 solved the issue, thanks!

mittagessen commented 1 year ago

OK, perfect. It's just a shortcut that doesn't run the filter if there aren't any empty lines, but as soon as there's even a single one, everything will get loaded again. I might have to filter manually instead of using pyarrow functions to avoid this.