Currently, our data preprocessing story is something like:
1) Get a corpus (either HF datasets or, better, jsonl files).
2) Ideally, shuffle the corpus (yourself).
3) Run a single-machine process to pretokenize the data into shards.
The last step can take many hours to days, and is a pretty annoying bottleneck that hits users right at the start of using levanter. No good.
We need an alternative. In this issue, I focus on the last step. Dealing with the second is covered in #34.
Roadmap
[x] Ray Cluster for tokenizing across multiple machines, without support for simultaneous writing and reading #112
[x] introduce a coordinator for simultaneous reads and writes #112
[x] Add resumption for writing #112
[x] auto-discovery of cluster on Slurm/TPU (can maybe hijack Jax's stuff) #118
[x] log metrics during preprocessing #117
[x] wire up on-the-fly tokenization during training #120
[ ] data loading from the global order with perfect reproducibility #119
Proposed Approach
(Design as of 2023-04-18)
Goals
We want to support the following:
1) Deterministic batches, even for a changing number of readers (or writers). That is, for any cluster size
during training, we want the same batches to be generated in the same order.
2) Sharded reading and writing. We want to be able to read and write from multiple shards in parallel.
3) Simultaneous reading and writing of shards. We want to be able to start training while we are still building the cache.
4) Fast resumption without losing too much progress. This applies to both writing and reading the cache. That is, when we resume a training run, we want to finish producing the cache and also jump to the right place in the cache for reads.
5) (eventually) shuffling/random access
6) Take advantage of the fact that we typically have idle, beefy CPUs on the machines where we're doing training
7) We want to be able to build the cache offline too.
8) We want to support batches that are composed of fragments of documents. In particular, we take a moving window of tokens from documents. This implies that the mapping from "documents" to "batches" is neither 1:1 nor easy to compute.
9) (eventually) ≈random access to tokens and not docs
10) Handle a variable number of examples generated per input doc
We want to support the following use cases:
1) We have a larger training dataset, and we want to draw samples from it more or less independently on a large number of machines.
We don't really care about "epochs"/"passes", but we do want to be able to handle resumes and be deterministic. Ideally, each
machine only reads from the chunks that it needs to read from.
2) We have a smaller validation dataset, and we want to do a single pass over it. We don't care about resuming, and it's ok if
we have to read the whole dataset on each machine.
3) Like (1) but we want to jump around the dataset. We still care about resuming and determinism, but don't care about epochs.
We focus on (1) and (2) for now.
Some terminology
Shard: A shard is a list of raw documents that have not been tokenized/preprocessed.
Chunk: A chunk is a list of processed documents that have been tokenized/preprocessed.
Reader: A reader is a process that reads from the cache. Typically there is one reader per machine.
Writer: A writer is a process that writes to the cache. Typically there is one writer per machine.
Global ordering: The global ordering is the ordering of chunks in the cache. This is the order in which
documents are read by readers. The global ordering is defined with respect to an "idealized" number of readers R*. (See below.)
Processor or Tokenizer: A function that takes a raw document and returns a processed document.
Example: An example is a single datum that is fed into the model. Examples are typically composed of fragments of documents.
For example, we might take a moving window of tokens from the concatenation of a list of preprocessed documents.
We say there are K input shards, W writers, and R readers. We assume K >= W (though typically K is not too large) and W ≈ R.
We produce N chunks. We also define an idealized number of readers R*, which defines the global ordering over the data.
Typically R* should be the maximum number of readers we expect to actually use.
Cache structure
We define a shard cache as a list of "chunks", where each chunk is a Parquet file (plus metadata) with an equal
number of documents (except for the last chunk of each shard).
Each chunk is a list of processed documents. Chunks are ordered round robin from the input shards, so that the c'th global chunk is
chunk c // K of shard c % K, as long as no shard has been exhausted yet. (After that, we remove shards that
have been exhausted and continue round robin.)
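The round-robin rule can be written out as a short sketch. It assumes the final number of chunks in every shard is already known, which is not true while writers are still running (that is what the coordinator described under cache construction is for). `global_chunk_order` is a hypothetical helper, not Levanter's API.

```python
from typing import List, Tuple


def global_chunk_order(chunks_per_shard: List[int]) -> List[Tuple[int, int]]:
    """Return (shard, chunk_within_shard) pairs in global round-robin order.

    While no shard is exhausted, global chunk c is chunk c // K of shard c % K.
    Once a shard runs out of chunks it is dropped and the round robin continues
    over the remaining shards.
    """
    next_chunk = [0] * len(chunks_per_shard)
    active = [k for k, n in enumerate(chunks_per_shard) if n > 0]
    order: List[Tuple[int, int]] = []
    while active:
        still_active = []
        for k in active:
            order.append((k, next_chunk[k]))
            next_chunk[k] += 1
            if next_chunk[k] < chunks_per_shard[k]:
                still_active.append(k)
        active = still_active
    return order
```

For shards with 2, 1, and 3 chunks, `global_chunk_order([2, 1, 3])` returns `[(0, 0), (1, 0), (2, 0), (0, 1), (2, 1), (2, 2)]`: shard 1 drops out after the first round and shard 0 after the second.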
We keep the following metadata:
For each shard, we keep a list of chunks written so far and whether or not we are done processing that shard.
For each chunk, we keep the number of documents, token counts/length of various fields, and the number of bytes.
(This metadata can be used for seeking.)
For the cache overall, we keep the global ordering of chunks, the number of chunks, and the number of documents.
Chunk format
A Chunk is an Apache Parquet file with schema dependent on the task. For example, for language modeling, we might have
just a sequence of input_ids per document. We use Apache Parquet because it's compact and doesn't require us to know
much about the datatypes we're using.
Chunks also have metadata stored in a separate json file. This metadata includes the total number of documents in the
chunk, as well as token counts/lengths of various fields. This metadata is used for seeking.
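Computing that sidecar metadata for one chunk might look like the sketch below. It assumes each processed document is a dict mapping field names (such as `input_ids`) to token lists, and that the JSON file sits next to the Parquet file with the same stem. The function names and the `num_docs`/`field_counts` keys are hypothetical, and writing the Parquet file itself is left out.

```python
import json
from pathlib import Path
from typing import Dict, List, Sequence


def chunk_metadata(docs: Sequence[Dict[str, List[int]]]) -> Dict:
    """Count documents and total tokens per field for one chunk."""
    field_counts: Dict[str, int] = {}
    for doc in docs:
        for field, values in doc.items():
            field_counts[field] = field_counts.get(field, 0) + len(values)
    return {"num_docs": len(docs), "field_counts": field_counts}


def metadata_path(chunk_path: Path) -> Path:
    """Sidecar JSON path for a chunk, e.g. chunk-0.parquet -> chunk-0.json."""
    return chunk_path.with_suffix(".json")


def write_chunk_metadata(chunk_path: Path, docs: Sequence[Dict[str, List[int]]]) -> Path:
    """Write the sidecar JSON next to the chunk and return its path."""
    path = metadata_path(chunk_path)
    path.write_text(json.dumps(chunk_metadata(docs)))
    return path
```

Keeping these counts outside the Parquet file lets a reader plan a seek by loading only small JSON files.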
Cache construction
We use Ray to manage the writers. Readers are managed by the main processes (though they call into Ray to get the data).
At a high level, we create a writer process for each shard, which produces chunks one by one. There is a central
writer coordinator process that receives chunks from each shard and adds them to the global ordering round robin.
When chunks are added, we make them available to an actor that readers can access to get chunks.
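The coordinator's round-robin merge can be sketched in plain single-process Python. This stands in for the Ray actor and is not Levanter's implementation. `ChunkCoordinator`, `chunk_finished`, and `shard_finished` are hypothetical names, and the sketch assumes each writer reports its shard's chunks in order. Because the global order is strict round robin, a chunk that arrives early (for example, shard 2's first chunk before shard 0's) is held back until every shard ahead of it in the rotation has either delivered its next chunk or finished. As a result, `global_order` is always a prefix that readers can safely consume.

```python
from collections import deque
from typing import Deque, List


class ChunkCoordinator:
    """Merges chunks reported by per-shard writers into one round-robin order."""

    def __init__(self, num_shards: int):
        self.pending: List[Deque[str]] = [deque() for _ in range(num_shards)]
        self.finished = [False] * num_shards
        self.active = list(range(num_shards))  # shards still in the rotation
        self.cursor = 0  # index into self.active of the shard whose chunk comes next
        self.global_order: List[str] = []  # prefix that readers may consume

    def chunk_finished(self, shard: int, chunk_name: str) -> None:
        self.pending[shard].append(chunk_name)
        self._advance()

    def shard_finished(self, shard: int) -> None:
        self.finished[shard] = True
        self._advance()

    def _advance(self) -> None:
        while self.active:
            shard = self.active[self.cursor]
            if self.pending[shard]:
                self.global_order.append(self.pending[shard].popleft())
                self.cursor = (self.cursor + 1) % len(self.active)
            elif self.finished[shard]:
                # Exhausted: drop it; the cursor now points at the following shard.
                self.active.pop(self.cursor)
                if self.active:
                    self.cursor %= len(self.active)
            else:
                return  # wait for this shard's next chunk to keep the order fixed
```

Holding chunks back means one slow shard can stall the visible prefix. That is the cost of making the order independent of writer timing.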
Reproducible Sharded Reading for Training
We want to be able to read from the cache in a way that is deterministic and reproducible, even if the number of readers
changes. We also want readers to only read from the chunks that they need to read from.
We pretend the list of data is infinite by cycling. We cannot track epochs.
NB Our goal is a deterministic ordering over examples, and not merely chunks or even documents.
Given a list of chunks and the idealized number of readers R*, we define the global ordering over chunks as follows:
First define R* iterators over chunks, with chunk_iterators[r] being defined as loop(all_chunks)[r::R*].
Next, define a function mk_examples(chunk_iterator) that takes an iterator over chunks and returns
an iterator over examples. Define chunk_examples[r] = mk_examples(chunk_iterators[r]).
This function depends on our sequence length, etc. Then the ordering over examples is:
chunk_examples[0][0], chunk_examples[1][0], ..., chunk_examples[R*-1][0], ..., chunk_examples[0][1], chunk_examples[1][1], ..., chunk_examples[R*-1][1], ...
that is, example[i] == chunk_examples[i % R*][i // R*]
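The definitions above can be sketched with `itertools`. Here `loop(all_chunks)[r::R*]` becomes `islice(cycle(all_chunks), r, None, R*)`. `mk_examples` is a hypothetical stand-in that concatenates documents and cuts non-overlapping windows of `seq_len` tokens; the real function may window differently. A chunk is modeled as a list of tokenized documents.

```python
from itertools import cycle, islice
from typing import Iterator, List

Chunk = List[List[int]]  # hypothetical model: a chunk is a list of tokenized documents


def chunk_iterator(all_chunks: List[Chunk], r: int, r_star: int) -> Iterator[Chunk]:
    """loop(all_chunks)[r::R*]: cycle forever, taking every R*-th chunk from r."""
    return islice(cycle(all_chunks), r, None, r_star)


def mk_examples(chunks: Iterator[Chunk], seq_len: int) -> Iterator[List[int]]:
    """Hypothetical mk_examples: concatenate documents, emit seq_len-token windows."""
    buffer: List[int] = []
    for chunk in chunks:
        for doc in chunk:
            buffer.extend(doc)
            while len(buffer) >= seq_len:
                yield buffer[:seq_len]
                buffer = buffer[seq_len:]


def global_examples(all_chunks: List[Chunk], r_star: int, seq_len: int) -> Iterator[List[int]]:
    """Yield example[i] == chunk_examples[i % R*][i // R*] for i = 0, 1, 2, ..."""
    streams = [mk_examples(chunk_iterator(all_chunks, r, r_star), seq_len) for r in range(r_star)]
    while True:
        for stream in streams:
            yield next(stream)
```

With `chunks = [[[0, 1, 2, 3]], [[10, 11, 12, 13]], [[20, 21, 22, 23]]]`, `list(islice(global_examples(chunks, r_star=2, seq_len=2), 6))` gives `[[0, 1], [10, 11], [2, 3], [12, 13], [20, 21], [0, 1]]`. Stream 0 reads chunks 0, 2, 1, ... and stream 1 reads chunks 1, 0, 2, ..., and the two streams are interleaved.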
If we have exactly R* readers (R == R*), then each `reader_iterator[r][j] == chunk_examples[r][j] == example[j * R* + r]`.
Moreover, if either R or R* is a multiple of the other, then we still get a nice property where
each reader reads from a strided slice of the chunk_iterators:
(Boring math)
If we have R readers, then reader_iterator[r][j] == example[j * R + r] == chunk_examples[(j * R + r) % R*][(j * R + r) // R*]
If we have R == n * R*, then reader_iterator[r][j] == example[j * R + r] == chunk_examples[(j * R + r) % R*][(j * R + r) // R*] == chunk_examples[r % R*][(j * n * R* + r) // R*] == chunk_examples[r % R*][j * n + r // R*], so each reader reads from
a strided slice (specifically islice(..., r//R*, None, n))
If we have R* == n * R, then reader_iterator[r][j] == example[j * R + r] == chunk_examples[(j * R + r) % R*][(j * R + r) // R*] == chunk_examples[R * (j % n) + r][j // n] (using r < R), and so each reader reads from n different chunk_examples streams:
it round-robins over chunk_examples[r], chunk_examples[r + R], ..., chunk_examples[r + (n - 1) * R].
For other cases (R and R* don't divide each other), there's no simple relationship between the reader and chunk iterators
and you end up reading from everywhere, but that's ok.
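The three cases can be checked with a small sketch. `reader_stream` is a hypothetical helper that takes a factory `chunk_examples(s)` returning a fresh iterator over chunk_examples[s]. In the divisible cases it opens only the streams the reader needs. In the general case it falls back to walking the whole global order and keeping every R-th example.

```python
from itertools import count, islice
from typing import Callable, Iterator, TypeVar

T = TypeVar("T")


def reader_stream(
    r: int, num_readers: int, r_star: int, chunk_examples: Callable[[int], Iterator[T]]
) -> Iterator[T]:
    """Yield example[j * R + r] for j = 0, 1, 2, ... where R = num_readers.

    chunk_examples(s) must return a fresh iterator over chunk_examples[s], with
    example[i] == chunk_examples[i % R*][i // R*].
    """
    R = num_readers
    if R % r_star == 0:
        # R == n * R*: a strided slice of a single stream.
        n = R // r_star
        return islice(chunk_examples(r % r_star), r // r_star, None, n)
    if r_star % R == 0:
        # R* == n * R: round-robin over streams r, r + R, ..., r + (n - 1) * R.
        n = r_star // R
        streams = [chunk_examples(R * k + r) for k in range(n)]

        def round_robin() -> Iterator[T]:
            while True:
                for stream in streams:
                    yield next(stream)

        return round_robin()

    # General case: advance every stream in global order, keep every R-th example.
    def strided_global() -> Iterator[T]:
        all_streams = [chunk_examples(s) for s in range(r_star)]
        for i in count():
            example = next(all_streams[i % r_star])
            if i % R == r:
                yield example

    return strided_global()
```

Feeding it fake streams whose k-th element is `(s, k)` shows that reader r's j-th example is `((j * R + r) % R*, (j * R + r) // R*)` in every case, matching the formulas above.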
Single-Pass Reading for Evaluation
When we want to do a single pass over the data, we don't cycle and we don't shuffle. We just read the data in order. Boring
and simple.
Resuming
We need to think about resuming in two cases: resuming writes and resuming reads.
Resuming Writes
Resuming writes is relatively easy, since we can just keep track of the number of chunks written for each shard and the
number of documents written for each chunk. Then you just skip to the appropriate document and start writing.
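Resuming one shard's writer from the per-shard metadata can be sketched as follows. The function name is hypothetical. The sketch assumes that a chunk is recorded only after it is fully written, that a shard's raw documents are always read in the same order, and that each raw document produces exactly one processed document. If a document can produce a variable number of outputs (goal 10), the metadata would need to count raw input documents instead.

```python
from itertools import islice
from typing import Iterable, Iterator, List, Tuple, TypeVar

T = TypeVar("T")


def resume_shard(
    docs: Iterable[T], chunk_doc_counts: List[int], finished: bool
) -> Tuple[int, Iterator[T]]:
    """Return (index of the next chunk to write, raw documents still to process).

    chunk_doc_counts[i] is the number of documents in the i-th chunk already
    written for this shard, as recorded in the shard's metadata.
    """
    next_chunk = len(chunk_doc_counts)
    if finished:
        return next_chunk, iter(())
    already_written = sum(chunk_doc_counts)
    # Skip the documents that are already in completed chunks.
    return next_chunk, islice(docs, already_written, None)
```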
Resuming Reads
We want to understand how to seek to the b'th batch.
There are two cases of resuming we need to think about:
1) The "easy" case where 1 example == 1 (preprocessed) document.
2) The "hard" case where the mapping from examples to documents is not 1:1, but there is some easily computable relationship.
In the first case, each reader r reads documents[r::R]. The bth batch
is documents[b * batch_size:(b+1) * batch_size]. Assuming batch_size % R == 0, then for the b'th batch, reader r
needs to read documents[b * batch_size + r: (b+1) * batch_size + r: R] == docs(chunk_iterator[r])[b * batch_size // R:(b+1) * batch_size // R].
If we know how many documents are in each chunk, then we can seek to the right place in the chunk.
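In the easy case, seeking needs only the per-chunk document counts from the metadata. This sketch assumes R == R*, so reader r's documents come from chunk_iterators[r] in order, and that `doc_counts` lists the document counts of those chunks in the order the reader visits them. `reader_start_doc` and `seek_docs` are hypothetical helpers.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import List, Tuple


def reader_start_doc(b: int, batch_size: int, num_readers: int) -> int:
    """Index, in reader r's own document stream, of the first document of batch b."""
    assert batch_size % num_readers == 0, "assumes batch_size % R == 0"
    return b * batch_size // num_readers


def seek_docs(doc_counts: List[int], target: int) -> Tuple[int, int]:
    """Map a document index in this reader's stream to (chunk position, offset).

    doc_counts[i] is the number of documents in the i-th chunk this reader reads,
    so the seek uses only metadata and opens no chunk data.
    """
    cumulative = list(accumulate(doc_counts))
    pos = bisect_right(cumulative, target)
    if pos == len(cumulative):
        raise IndexError("target document is past the end of the known chunks")
    start = cumulative[pos - 1] if pos > 0 else 0
    return pos, target - start
```

For example, with `batch_size=8` and two readers, batch 1 starts at document `reader_start_doc(1, 8, 2) == 4` of each reader's stream. With chunks holding 3, 3, and 2 documents, `seek_docs([3, 3, 2], 4) == (1, 1)`, so the reader starts at the second document of its second chunk.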
The second case is broadly similar. In particular, we consider the case where we take moving windows of concatenated documents.
If our metadata includes token counts, then to seek to the b'th batch we can skip chunks until we pass b * batch_size * tokens_per_example // R tokens.
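For moving windows, the same idea works at token granularity. This sketch adds assumptions beyond the text: windows are non-overlapping, `seq_len` tokens long, and consume every token of the reader's concatenated stream; R == R*; and the metadata records per-document token lengths for each chunk the reader visits. Under those assumptions, batch b starts at token `b * (batch_size // R) * seq_len` of reader r's stream. Whole chunks are skipped using their token totals, and only the chunk containing the target is examined document by document. Helper names are hypothetical.

```python
from typing import List, Tuple


def reader_start_token(b: int, batch_size: int, seq_len: int, num_readers: int) -> int:
    """Token offset, in reader r's own stream, where batch b's first window starts."""
    assert batch_size % num_readers == 0, "assumes batch_size % R == 0"
    return b * (batch_size // num_readers) * seq_len


def seek_tokens(chunk_doc_lengths: List[List[int]], target: int) -> Tuple[int, int, int]:
    """Return (chunk position, document index, token offset) of the target-th token.

    chunk_doc_lengths[c][d] is the token length of document d in the c-th chunk
    this reader reads, taken from the chunk metadata.
    """
    skipped = 0
    for c, doc_lengths in enumerate(chunk_doc_lengths):
        chunk_tokens = sum(doc_lengths)
        if skipped + chunk_tokens <= target:
            skipped += chunk_tokens  # the whole chunk precedes the target: skip it
            continue
        for d, n in enumerate(doc_lengths):
            if skipped + n > target:
                return c, d, target - skipped
            skipped += n
    raise IndexError("target token is past the end of the known chunks")
```

For example, with `batch_size=4`, `seq_len=4`, and two readers, batch 1 starts at token `reader_start_token(1, 4, 4, 2) == 8`. With chunks whose documents have lengths `[[4, 4], [3, 5], [6]]`, `seek_tokens(..., 8)` returns `(1, 0, 0)`, which is the first token of the second chunk.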
Shuffling
A brief digression
Why do we shuffle in machine learning? Shuffling reduces variance in the gradients. If we have batches
where every example is from the same document/domain, then the gradients for those batches will be correlated.
That said, in our setting where we use moving windows from documents, if we round-robin from chunks (which are produced
from different documents), and R* is roughly equal to the batch size, then we will read from a different chunk for every
example in a batch, which reduces correlation within a batch.
However, we still have (undesirable) correlation between batches: if we
read from chunks consecutively and our documents are long, then many examples in the next batch will be from the
same document as an example in the previous batch. Ideally this wouldn't happen. I'm not convinced that it matters
that much.
Proper shuffling is incompatible with streaming at a fundamental level. Our choices are something like:
Randomly shuffle before preprocessing. Makes life a bit less pleasant for people with a new dataset. Can't be changed after preprocessing. Doesn't solve the problem of correlated batches.
Reservoir sampling. Makes resumes hard, but is easy to implement.
"Epochal" reservoir sampling, where we periodically "flush" the reservoir. Resumes are easier because you can start from the latest "epoch".
No shuffling in the first pass, but shuffle in subsequent passes.
Shuffle within a range of chunks that grows as the run progresses.
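The "epochal" option can be illustrated with a seeded block shuffle, which is a simpler relative of reservoir sampling and not something this design commits to. The stream is cut into blocks of `block_size` items, and each block is permuted by an RNG seeded only by the seed and the block index. A resumed run can therefore jump to any block boundary and reproduce exactly the same order. `block_shuffle` and its parameters are hypothetical. Items are mixed only within a block, so correlation across block boundaries remains.

```python
import random
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def block_shuffle(
    items: Iterable[T], block_size: int, seed: int, start_block: int = 0
) -> Iterator[T]:
    """Shuffle a stream in independent blocks; resumable at any block boundary.

    On resume, pass the same (unshuffled) stream from the beginning together with
    start_block; the blocks before it are skipped without being shuffled.
    """
    it = islice(items, start_block * block_size, None)
    block_index = start_block
    while True:
        block = list(islice(it, block_size))
        if not block:
            return
        # A string seed is hashed deterministically, so the permutation is the
        # same in every process and on every run.
        rng = random.Random(f"{seed}:{block_index}")
        rng.shuffle(block)
        yield from block
        block_index += 1
```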
My hunch is that we can skip this for now, and revisit if we find that it's a problem.