stanford-crfm / mistral

Mistral: A strong, northwesterly wind: Framework for transparent and accessible large-scale language model training, built with Hugging Face 🤗 Transformers.

Streaming data for larger datasets #126

Closed jthickstun closed 2 years ago

jthickstun commented 2 years ago

Is your feature request related to a problem? Please describe. Currently, the entire dataset is loaded into memory during preprocessing via get_auto_dataset. When using multiple devices (nproc_per_node), the dataset also appears to be replicated per worker. Training with 4 processes on openwebtext therefore requires, e.g., 4×40 GB = 160 GB of RAM, which is not always available on machines optimized for GPU acceleration.

Describe the solution you'd like A data streaming solution that doesn't load the entire dataset into memory. Alternatively (maybe this is easier), a caching solution that allows preprocessing on a high-memory machine and later training in a lower-memory environment.
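
For what it's worth, 🤗 datasets has a streaming mode that returns an IterableDataset and applies map lazily, so the corpus is never fully materialized in RAM. A rough sketch of the kind of thing I mean, assuming the hub's openwebtext loader supports streaming (the tokenizer and the map call are placeholders, not Mistral's actual preprocessing):

```python
from datasets import load_dataset
from transformers import AutoTokenizer

# streaming=True returns an IterableDataset that reads the data lazily
# instead of downloading and loading the whole corpus into RAM.
stream = load_dataset("openwebtext", split="train", streaming=True)

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# map() on an IterableDataset is applied on the fly as examples are consumed.
tokenized = stream.map(lambda example: tokenizer(example["text"]))

# Only the examples actually iterated over are ever held in memory.
for i, example in enumerate(tokenized):
    print(len(example["input_ids"]))
    if i >= 2:
        break
```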

Describe alternatives you've considered I tried preprocessing on a high-memory machine and then using the cached dataset in my training environment. But the chain of commands in get_auto_dataset requires loading the intermediate results:

tokenized_dataset = dataset.map(...)
lm_dataset = tokenized_dataset.map(...)

Please note that I am unsure whether short-circuiting these steps with a final preprocessed version of the dataset would solve the problem: it is possible that the full, final preprocessed dataset is still ultimately loaded into memory, in which case a solution would require changing not just the preprocessing logic but also the training loop, so that it streams data from disk rather than from RAM.
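
One way to make this concrete with 🤗 datasets would be to persist the final preprocessed dataset once and reload it memory-mapped (the paths below are placeholders); whether this actually keeps the data out of RAM during training is exactly the part I'm unsure about:

```python
from datasets import load_from_disk

# On the high-memory machine, after preprocessing finishes (hypothetical path):
# lm_dataset.save_to_disk("/shared/openwebtext-gpt2-seq1024")

# On the training machine: load the cached Arrow files without copying them
# into RAM; keep_in_memory=False should leave them memory-mapped.
lm_dataset = load_from_disk("/shared/openwebtext-gpt2-seq1024", keep_in_memory=False)
print(lm_dataset)
```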

dlwh commented 2 years ago

Yeah, I've noticed this as well. I recently opened #116 as a kind of placeholder issue for tracking this, though it doesn't explicitly hit on the point of not keeping data in RAM.

I believe the right thing to do is to use a dataloader for tokenizing and grouping rather than what we're doing, but I'm not 100% sure.

jthickstun commented 2 years ago

Maybe some thought is required re: the details of streaming. When we tokenize the whole dataset up front, we can split it into seq_len-sized chunks. Without pre-tokenizing, streaming might look something like:

  1. seek to a random byte offset in the dataset file
  2. skip to the beginning of the next token (does PreTrainedTokenizer expose this functionality?)
  3. tokenize context_length tokens starting from this location (not sure if this functionality is exposed either)

It could be annoying to get the Tokenizer interface to do this. It also changes the sampling strategy: we would get chunks at random offsets instead of fixed chunks, and we would sample with replacement instead of without replacement. I dimly recall reading some work suggesting it is better to do a complete pass through your data (sampling without replacement) rather than a randomized approach (I can't find the reference now; it was an analysis of negative correlations).
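
To make steps 1-3 concrete, here is a rough sketch of that kind of random-offset sampling over a single raw text file (the path, sequence length, and over-read factor are all made up). Since PreTrainedTokenizer has no notion of byte offsets, step 2 is approximated by skipping to the next newline and step 3 by over-reading bytes and truncating after tokenization:

```python
import random
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
CONTEXT_LENGTH = 1024        # assumed seq_len
BYTES_PER_TOKEN = 8          # rough over-read factor, just a guess

def sample_chunk(path):
    """Draw one context_length chunk from a random position in the file."""
    with open(path, "rb") as f:
        size = f.seek(0, 2)                 # file size in bytes
        f.seek(random.randrange(size))      # step 1: random byte offset
        f.readline()                        # step 2 (approx.): skip to the next line boundary
        raw = f.read(CONTEXT_LENGTH * BYTES_PER_TOKEN)
    text = raw.decode("utf-8", errors="ignore")
    # step 3 (approx.): tokenize more text than we need, then truncate.
    return tokenizer(text, truncation=True, max_length=CONTEXT_LENGTH)["input_ids"]

# chunk = sample_chunk("openwebtext.txt")   # hypothetical single-file corpus
```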

It might be better/safer/easier to stick to the current pre-processing strategy (maybe being smarter about downloading/storing a single preprocessed copy) and add some logic for asynchronously streaming chunks of the pre-processed dataset from disk rather than loading the whole dataset into memory. But I could be biased towards solving my particular problem (limited RAM) in a minimally invasive way and missing the bigger picture of streaming.

dlwh commented 2 years ago

So, I think a "good enough" version is actually not too bad to do; it's implemented in, e.g., MosaicML's Composer library, and basically amounts to turning what we have into streaming. The trick is: how do you split into seq_len chunks? Well, we just need to do it the same way it is already implemented in Mistral: load batches of 1000 docs and split those into seq_len chunks (concatenating as necessary). You can get away with dropping the last little bit from each chunk. If that's done in an on-demand/streaming way, I think it'll have good enough GPU-utilization properties while not using too much RAM.
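
Roughly, as a generator, it might look something like this (the GPT-2 tokenizer, the EOS separator, and feeding docs from a streaming 🤗 dataset are my assumptions here, not a verbatim copy of the current map()):

```python
from itertools import islice
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
SEQ_LEN = 1024        # assumed seq_len
BATCH_DOCS = 1000     # tokenize 1000 docs at a time, as in the current preprocessing

def stream_chunks(doc_iterator):
    """Tokenize batches of docs on demand, concatenate, and yield seq_len chunks."""
    while True:
        batch = list(islice(doc_iterator, BATCH_DOCS))
        if not batch:
            return
        concatenated = []
        for ids in tokenizer(batch)["input_ids"]:
            # eos separator between docs is an assumption, not necessarily
            # what Mistral's current grouping does
            concatenated.extend(ids + [tokenizer.eos_token_id])
        # drop the partial chunk at the end of each batch
        for start in range(0, len(concatenated) - SEQ_LEN + 1, SEQ_LEN):
            yield concatenated[start:start + SEQ_LEN]

# Usage sketch: docs could come from a streaming 🤗 dataset:
# for chunk in stream_chunks(example["text"] for example in streamed_dataset):
#     ...
```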

I think a better solution looks somewhat like what you suggest, though I don't think it's worth trying to do random disk seeks (maybe a one-time shuffle, but I dunno). @siddk and I talked a little about this, and I think the solution looks something like: scan through the documents, keeping a buffer of "extra" tokens per file; tokenize each file while holding onto byte offsets; and record those offsets (plus the file id) in a ledger. Then resumes and reproducibility are pretty simple.
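
As a very rough sketch of the ledger idea (one document per line, jsonl output, and all names here are placeholders, not an agreed-upon design):

```python
import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
SEQ_LEN = 1024  # assumed seq_len

def build_ledger(paths, ledger_path="ledger.jsonl"):
    """Single pass over the corpus: tokenize each document, carry leftover
    ("extra") tokens forward, and record the file and byte offset at which
    each seq_len chunk was completed, so training can resume deterministically."""
    buffer = []
    with open(ledger_path, "w") as ledger:
        for file_id, path in enumerate(paths):
            offset = 0
            with open(path, encoding="utf-8") as f:
                for line in f:                       # one document per line (assumption)
                    buffer.extend(tokenizer(line)["input_ids"])
                    while len(buffer) >= SEQ_LEN:
                        ledger.write(json.dumps({"file_id": file_id, "byte_offset": offset}) + "\n")
                        buffer = buffer[SEQ_LEN:]
                    offset += len(line.encode("utf-8"))
```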

I'd be very surprised if HF supported scanning to the next token from an arbitrary byte offset in an efficient way. There are typically multiple tokenizations compatible with a given BPE vocabulary (BPE just picks the greedy one). There's actually a neat paper that exploits this property of BPE: https://arxiv.org/abs/1910.13267

dlwh commented 2 years ago

It turns out to be trickier than I thought because HF datasets is super fragile.

Could you try flipping the keep_in_memory flag here to false? https://github.com/stanford-crfm/mistral/blob/main/src/corpora/auto.py#L39
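
For context, that flag is the standard 🤗 datasets keep_in_memory argument; a self-contained example of the kind of call I mean (the dataset name and cache dir are placeholders, not what auto.py actually passes):

```python
import datasets

# keep_in_memory=False leaves the Arrow cache memory-mapped on disk instead of
# copying it into process memory.
dataset = datasets.load_dataset(
    "openwebtext",
    cache_dir="/path/to/artifacts",   # placeholder
    keep_in_memory=False,
)
```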

dlwh commented 2 years ago

And could you also try out this branch https://github.com/stanford-crfm/mistral/compare/stream_dataset?expand=1 with the --dataset.streaming=true flag?

jthickstun commented 2 years ago

The keep_in_memory flag works for me: I have a pre-processed cache of the dataset files (which I previously generated on a machine with plenty of RAM), and I just need to train from this cache without loading it all into memory.

Edit: oops! I thought this had worked, but my job died after running for a while. I suspect this flag gives lazy dataset loading, but eventually, as datapoints are accessed, the full dataset ends up in memory?

dlwh commented 2 years ago

The core bug here has been fixed. Still no streaming, but I'll track that in #116.