stanford-crfm / levanter

Legible, Scalable, Reproducible Foundation Models with Named Tensors and Jax
https://levanter.readthedocs.io/en/latest/
Apache License 2.0

Possible Ray OOM when processing long documents #568

Open versae opened 2 months ago

versae commented 2 months ago

The workers OOM a few times during the tokenization of a dataset with very long documents (over 1M chars), but the job succeeds in the end after adjusting the batch size of BatchTokenizer and just retrying.

@dlwh:

Yeah, so I think what's happening is that Ray creates one process per CPU on the node (even though we always schedule ~16 or so CPUs per tokenization task) and reuses those processes to process batches; Ray seems to do some kind of round-robin scheduling across them. This is fine and good, except HF tokenizers retains memory somehow in those processes (probably as an optimization?), and memory use seems to be directly related to the doc sizes. This means we're retaining num_processes * whatever RAM that is, and it OOMs on TPU for large enough books. If Ray would just reuse processes, or not allocate so many of them, it would be fine.

Could this be the reason? https://github.com/stanford-crfm/levanter/blob/2516d06be6fec2ff5660f144649a9f5f577b06e9/src/levanter/data/text.py#L308-L312

It seems that, in most cases, this enables multithreading on the Rust side of the tokenizer.
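A minimal sketch of the two knobs being discussed here, assuming (as in the comments above) that the HF tokenizers' Rust thread pool and the number of Ray worker processes per node are what drive the retained memory; `num_cpus=16` is just the example figure from the thread, not a Levanter default:

```python
import os

# Must be set before the tokenizer does any work in this process;
# otherwise the Rust thread pool may already have been spawned.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import ray

# Assumption: capping the CPUs Ray sees also caps how many worker
# processes it creates and reuses, so total retained memory is roughly
# 16 x per-document peak instead of (CPUs on node) x per-document peak.
ray.init(num_cpus=16)
```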

dlwh commented 1 month ago

@versae have you tried disabling it and seeing if that fixes it?

versae commented 1 month ago

Yes, I now set TOKENIZERS_PARALLELISM to false in my setup scripts. It seems to help, but I'm not sure it's the definitive fix.

dlwh commented 1 month ago

Interesting, OK, I guess it's time to give up on that then. Do you also reduce the batch size?

versae commented 1 month ago

Yes, for processing very, very long documents (tens of millions of tokens) I had to set it to 1 and set TOKENIZERS_PARALLELISM to false. Slower, but at least it hasn't failed me yet. Is the tokenizer's batch size something we can set in the training config?
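For illustration, a sketch of what the batch-size-1 workaround amounts to, using a plain HF fast tokenizer rather than Levanter's actual BatchTokenizer plumbing (the model id and `docs` iterable are placeholders):

```python
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # keep Rust threading off

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model

def tokenize_long_docs(docs):
    # Batch size 1: each multi-million-character document is encoded on
    # its own, so peak memory is bounded by one document at a time
    # rather than by batch_size documents held simultaneously.
    for doc in docs:
        yield tokenizer(doc)["input_ids"]
```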

versae commented 1 month ago

OK, I think I found a winning combination: setting SLURM_CPUS_ON_NODE=16 TOKENIZERS_PARALLELISM=false seems to work with the current batch size. On a TPU v4-32, 3 out of 4 nodes sometimes fail right after loading the weights, but the one that keeps running is able to finish the tokenization. So I just leave it running, and when it's done I restart training without SLURM_CPUS_ON_NODE.
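A sketch of launching with that environment, assuming (per the report above, not Levanter's documentation) that the launcher consults SLURM_CPUS_ON_NODE when sizing the per-node worker pool; `train_script.py` is a placeholder for however training is normally started:

```python
import os
import subprocess

env = dict(os.environ)
env["SLURM_CPUS_ON_NODE"] = "16"         # cap the CPU count the launcher sees
env["TOKENIZERS_PARALLELISM"] = "false"  # keep Rust threading off in workers

# Placeholder command: substitute the real training entry point.
subprocess.run(["python", "train_script.py"], env=env, check=True)
```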