stanford-crfm / levanter

Legible, Scalable, Reproducible Foundation Models with Named Tensors and Jax
https://levanter.readthedocs.io/en/latest/
Apache License 2.0

On-The-Fly Caching (Tracking Issue) #99

Closed dlwh closed 1 year ago

dlwh commented 1 year ago

Currently, our data preprocessing story is something like:

The last step can take many hours to days, and is a pretty annoying bottleneck that hits users right at the start of using levanter. No good.

We need an alternative. In this issue, I focus on the last step. Dealing with the second is covered in #34.

Roadmap

Proposed Approach

(Design as of 2023-04-18)

Goals

We want to support the following:

1) Deterministic batches, even for a changing number of readers (or writers). That is, for any cluster size during training, we want the same batches to be generated in the same order.
2) Sharded reading and writing. We want to be able to read and write from multiple shards in parallel.
3) Simultaneous reading and writing of shards. We want to be able to start training while we are still building the cache.
4) Fast resumption without losing too much progress. This applies to both writing and reading the cache. That is, when we resume a training run, we want to finish producing the cache and also jump to the right place in the cache for reads.
5) (eventually) Shuffling/random access.
6) Takes advantage of the fact that we typically have idle, beefy CPUs on the machines where we're doing training.
7) We want to be able to build the cache offline too.
8) We want to support batches that are composed of fragments of documents. In particular, we take a moving window of tokens from documents. This implies that the mapping from "documents" to "batches" is not 1:1, or easy to compute.
9) (eventually) ≈Random access to tokens and not docs.
10) Can handle a variable number of examples being generated per input doc.

We want to support the following use cases:

1) We have a larger training dataset, and we want to draw samples from it more or less independently on a large number of machines. We don't really care about "epochs"/"passes", but we do want to be able to handle resumes and be deterministic. Ideally, each machine only reads from the chunks that it needs to read from.
2) We have a smaller validation dataset, and we want to do a single pass over it. We don't care about resuming, and it's ok if we have to read the whole dataset on each machine.
3) Like (1), but we want to jump around the dataset. We still care about resuming and determinism, but don't care about epochs.

We focus on (1) and (2) for now.

Some terminology

We say there are K input shards, W writers, and R readers. We assume K >= W (though typically K is not too large) and W ≈ R. We produce N chunks. We also define an idealized number of readers R*, which defines the global ordering over the data. Typically R* should be the maximum number of readers we expect to actually use.

Cache structure

We define a shard cache as a list of "chunks", where each chunk is a parquet file (plus metadata) with an equal number of documents (except for the last chunk of each shard). Each chunk is a list of processed documents. Chunks are ordered round robin from the input shards, so that the c'th global chunk is the (c // K)'th chunk of shard c % K, so long as all shards have at least c // K chunks. (After that, we remove shards that have been exhausted and continue round robin.) We keep the following metadata:
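For concreteness, here is a minimal sketch of that round-robin chunk ordering over already-written per-shard chunk lists (the name `shard_chunks` is illustrative, not Levanter's actual API):

```python
import itertools

def global_chunk_order(shard_chunks):
    """Yield chunks round robin across shards, dropping shards as they are exhausted."""
    # row i holds the i'th chunk of every shard (None once a shard has run out of chunks)
    for row in itertools.zip_longest(*shard_chunks):
        yield from (chunk for chunk in row if chunk is not None)
```

For example, `list(global_chunk_order([["a0", "a1", "a2"], ["b0"]]))` gives `["a0", "b0", "a1", "a2"]`.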

Chunk format

A Chunk is an Apache Parquet file with schema dependent on the task. For example, for language modeling, we might have just a sequence of input_ids per document. We use Apache Parquet because it's compact and doesn't require us to know much about the datatypes we're using.

Chunks also have metadata stored in a separate json file. This metadata includes the total number of documents in the chunk, as well as token counts/lengths of various fields. This metadata is used for seeking.
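As a rough sketch of what writing one chunk might look like (field names like `input_ids` are illustrative, not the exact schema):

```python
import json
import pyarrow as pa
import pyarrow.parquet as pq

def write_chunk(path, docs):
    """Write a chunk of processed documents as a Parquet file plus a JSON metadata sidecar."""
    table = pa.table({"input_ids": [doc["input_ids"] for doc in docs]})
    pq.write_table(table, f"{path}.parquet")
    metadata = {
        "num_docs": len(docs),
        # per-field token counts/lengths, used later for seeking
        "input_ids_total_length": sum(len(doc["input_ids"]) for doc in docs),
    }
    with open(f"{path}.metadata.json", "w") as f:
        json.dump(metadata, f)
```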

Cache construction

We use Ray to manage the writers. Readers are managed by the main processes (though they call into Ray to get the data). At a high level, we create a writer process for each shard, which produces chunks one by one. There is a central writer coordinator process that receives chunks from each shard and adds them to the global ordering round robin. When chunks are added, we make them available to an actor that readers can access to get chunks.
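A heavily simplified sketch of that structure with Ray (the actor, method, and `processor` names are made up for illustration, not Levanter's actual classes):

```python
import ray

@ray.remote
class ChunkCacheCoordinator:
    """Receives finished chunks from shard writers, commits them to the global
    round-robin ordering, and lets readers fetch chunks by global index."""

    def __init__(self, num_shards):
        self.pending = [[] for _ in range(num_shards)]  # finished chunks awaiting their turn, per shard
        self.exhausted = [False] * num_shards           # shards that will produce no more chunks
        self.global_order = []                          # committed global chunk order
        self.rotation = list(range(num_shards))         # shards still in the round robin
        self.pointer = 0                                # whose turn it is

    def chunk_finished(self, shard_idx, chunk):
        self.pending[shard_idx].append(chunk)
        self._commit()

    def shard_exhausted(self, shard_idx):
        self.exhausted[shard_idx] = True
        self._commit()

    def get_chunk(self, global_idx):
        # None means "not committed yet"; readers wait and retry.
        return self.global_order[global_idx] if global_idx < len(self.global_order) else None

    def _commit(self):
        # Commit as many chunks as possible in strict round-robin order over active shards.
        while self.rotation:
            shard = self.rotation[self.pointer]
            if self.pending[shard]:
                self.global_order.append(self.pending[shard].pop(0))
                self.pointer = (self.pointer + 1) % len(self.rotation)
            elif self.exhausted[shard]:
                del self.rotation[self.pointer]          # drop the finished shard from the rotation
                if self.rotation:
                    self.pointer %= len(self.rotation)
            else:
                break  # the shard whose turn it is hasn't produced its next chunk yet


@ray.remote
def write_shard(shard_idx, shard_source, coordinator, processor):
    """One writer per input shard: process documents into chunks and report them as they finish."""
    for chunk in processor(shard_source):  # `processor` yields finished chunks one by one
        ray.get(coordinator.chunk_finished.remote(shard_idx, chunk))
    ray.get(coordinator.shard_exhausted.remote(shard_idx))
```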

Reproducible Sharded Reading for Training

We want to be able to read from the cache in a way that is deterministic and reproducible, even if the number of readers changes. We also want readers to only read from the chunks that they need to read from. We pretend the list of data is infinite by cycling. We cannot track epochs.

NB Our goal is a deterministic ordering over examples, and not merely chunks or even documents.

Given a list of chunks and the idealized number of readers R*, we define the global ordering over chunks as follows: first, define R* iterators over chunks, with chunk_iterators[r] defined as loop(all_chunks)[r::R*].

Next, define a function mk_examples(chunk_iterator) that takes an iterator over chunks and returns a list of examples. Define chunk_examples[r] = mk_examples(chunk_iterators[r]). This function depends on our sequence length, etc. Then the ordering over examples is:

chunk_examples[0][0], chunk_examples[1][0], ..., chunk_examples[R*-1][0], chunk_examples[0][1], chunk_examples[1][1], ..., chunk_examples[R*-1][1], ...

That is, `example[i] == chunk_examples[i % R*][i // R*]`.
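A small sketch of this construction (with `mk_examples` left abstract, since it depends on the task and sequence length):

```python
from itertools import cycle, islice

def chunk_iterator(all_chunks, r, R_star):
    """loop(all_chunks)[r::R*]: the r'th of the R* round-robin chunk streams."""
    return islice(cycle(all_chunks), r, None, R_star)

def global_examples(all_chunks, R_star, mk_examples):
    """Yield examples in the canonical order: example[i] == chunk_examples[i % R*][i // R*]."""
    chunk_examples = [mk_examples(chunk_iterator(all_chunks, r, R_star)) for r in range(R_star)]
    while True:
        for r in range(R_star):
            yield next(chunk_examples[r])
```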

If we have R* readers, then each `reader_iterator[r][j] == chunk_examples[r][j] == example[j * R* + r]`. Moreover, if either R or R* is a multiple of the other, then we still get a nice property where each reader reads from a strided slice of the chunk_examples:

(Boring math)

- In general, with R readers, `reader_iterator[r][j] == example[j * R + r] == chunk_examples[(j * R + r) % R*][(j * R + r) // R*]`.
- If `R == n * R*`, then `reader_iterator[r][j] == example[j * R + r] == chunk_examples[(j * R + r) % R*][(j * R + r) // R*] == chunk_examples[r % R*][(j * n * R* + r) // R*] == chunk_examples[r % R*][j * n + r // R*]`, so each reader reads from a strided slice of a single chunk_examples stream (specifically `islice(chunk_examples[r % R*], r // R*, None, n)`).
- If `R* == n * R`, then `reader_iterator[r][j] == example[j * R + r] == chunk_examples[(j * R + r) % R*][(j * R + r) // R*] == chunk_examples[R * (j % n) + r][(j * R + r) // R*]`, so each reader round-robins over n different chunk_examples streams.

For other cases (R and R* don't divide each other), there's no simple relationship between the reader and chunk iterators and you end up reading from everywhere, but that's ok.
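The divisible cases are easy to sanity-check numerically. Here is a small self-contained sketch (repeating the helpers from the sketch above) that uses a toy `mk_examples` and verifies the R == n * R* case:

```python
from itertools import cycle, islice

def chunk_iterator(all_chunks, r, R_star):              # loop(all_chunks)[r::R*]
    return islice(cycle(all_chunks), r, None, R_star)

def global_examples(all_chunks, R_star, mk_examples):   # example[i] == chunk_examples[i % R*][i // R*]
    streams = [mk_examples(chunk_iterator(all_chunks, r, R_star)) for r in range(R_star)]
    while True:
        for s in streams:
            yield next(s)

def toy_mk_examples(chunks):                            # toy task: each chunk yields 4 examples
    for c in chunks:
        for pos in range(4):
            yield (c, pos)

chunks, R_star, n = list(range(8)), 4, 2
R = n * R_star                                          # R is a multiple of R*
for r in range(R):
    # what reader r sees: every R'th example of the global order, starting at offset r
    via_global = list(islice(islice(global_examples(chunks, R_star, toy_mk_examples), r, None, R), 16))
    # the claimed strided slice of a single chunk_examples stream
    via_slice = list(islice(islice(toy_mk_examples(chunk_iterator(chunks, r % R_star, R_star)),
                                   r // R_star, None, n), 16))
    assert via_global == via_slice
```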

Single-Pass Reading for Evaluation

When we want to do a single pass over the data, we don't cycle and we don't shuffle. We just read the data in order. Boring and simple.

Resuming

We need to think about resuming in two cases: resuming writes and resuming reads.

Resuming Writes

Resuming writes is relatively easy, since we can just keep track of the number of chunks written for each shard and the number of documents written for each chunk. Then we just skip to the appropriate document and start writing.
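For instance, the per-shard bookkeeping could be as small as a JSON ledger (field names here are illustrative):

```python
import json

def docs_to_skip(ledger_path, docs_per_chunk):
    """How many input documents of this shard were already processed before we stopped."""
    with open(ledger_path) as f:
        ledger = json.load(f)
    # complete chunks already on disk, plus documents already written into the partial chunk
    return ledger["chunks_written"] * docs_per_chunk + ledger["docs_in_current_chunk"]
```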

Resuming Reads

We want to understand how to seek to the b'th batch.

There are two cases of resuming we need to think about:

1) The "easy" case, where 1 example == 1 (preprocessed) document.
2) The "hard" case, where the mapping from examples to documents is not 1:1, but there is some easily computable relationship.

In the first case, each reader r reads documents[r::R]. The bth batch is documents[b * batch_size:(b+1) * batch_size]. Assuming batch_size % R == 0, then for the b'th batch, reader r needs to read documents[b * batch_size + r: (b+1) * batch_size + r: R] == docs(chunk_iterator[r])[b * batch_size // R:(b+1) * batch_size // R]. If we know how many documents are in each chunk, then we can seek to the right place in the chunk.

The second case is broadly similar. In particular, we consider the case where we take moving windows of concatenated documents. If our metadata includes token counts, then we can skip chunks until we have passed b * batch_size * tokens_per_example // R tokens.
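Both cases boil down to the same skip-by-count walk over the chunk metadata. A sketch, assuming each chunk's metadata exposes `num_docs` and `num_tokens` (illustrative names):

```python
def seek(chunks_for_reader, target, count_of):
    """Skip whole chunks along reader r's chunk sequence until fewer than one chunk's worth
    of units remain. Returns (index of the chunk to open, offset within that chunk)."""
    for idx, chunk in enumerate(chunks_for_reader):
        size = count_of(chunk)
        if target < size:
            return idx, target
        target -= size

# Case 1 (1 example == 1 document): reader r's first document for batch b is
#   seek(chunks_for_reader, b * batch_size // R, lambda c: c.num_docs)
# Case 2 (moving token windows): skip whole chunks of tokens instead:
#   seek(chunks_for_reader, b * batch_size * tokens_per_example // R, lambda c: c.num_tokens)
```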

Shuffling

A brief digression

Why do we shuffle in machine learning? Shuffling reduces variance in the gradients. If we have batches where every example is from the same document/domain, then the gradients for those batches will be correlated.

That said, in our setting where we use moving windows from documents, if we round-robin from chunks (which are produced from different documents), and R* is roughly equal to the batch size, then we will read from a different chunk for every example in a batch, which reduces correlation within a batch.

However, we still have (undesirable) correlation between batches: if we read from chunks consecutively and our documents are long, then many examples in the next batch will be from the same document as an example in the previous batch. Ideally this wouldn't happen. I'm not convinced that it matters that much.

Proper shuffling is incompatible with streaming at a fundamental level. Our choices are something like:

My hunch is that we can skip this for now, and revisit if we find that it's a problem.

dlwh commented 1 year ago

gonna close this since only 1 item is left and it's a nice-to-have