foundation-model-stack / fms-fsdp

🚀 Efficiently (pre)training foundation models with native PyTorch features, including FSDP for training and SDPA implementation of Flash attention v2.
https://pytorch.org/docs/stable/fsdp.html
Apache License 2.0

A revisit on improving the performance of Data Loader #70

Open lchu-ibm opened 3 months ago

lchu-ibm commented 3 months ago

We have been noticing a training slowdown introduced by our dataloader. Upon further investigation, we identified the cause: our dataset class maintains a number of very large lists.

Background

Each logical shard maintains a list of (dataset_id, shard_id, doc_id) tuples in order to track documents. e.g. ("c4", 3, 110) refers to the 110th document inside the file dataset_root_folder/dataset=c4/xxx.part3.arrow. When we distribute billions of documents over thousands of logical shard workers, each worker gets a list of millions of such (dataset_id, shard_id, doc_id) tuples, so in total we are maintaining hundreds of GBs worth of lists internally.
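For a rough sense of scale, a back-of-the-envelope sketch (the document count and entry shape here are illustrative assumptions, not measured repo numbers):

```python
import sys

# Approximate CPython cost of one (dataset_id, shard_id, doc_id) entry:
# shallow tuple size plus the sizes of its three elements.
entry = ("c4", 3, 110)
per_entry = sys.getsizeof(entry) + sum(sys.getsizeof(x) for x in entry)

num_docs = 2_000_000_000  # assume ~2B documents for illustration
total_gb = num_docs * per_entry / 1e9
print(f"~{per_entry} bytes per entry, ~{total_gb:.0f} GB across all lists")
```

Interning of small ints and strings means real usage varies, but the order of magnitude lands in the hundreds of GBs described above.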

Why did we do this in the first place? Datasets are assumed to be unshuffled, so we need to shuffle our billions of (dataset_id, shard_id, doc_id) tuples; each logical shard therefore maintains a shuffled list containing millions of them. Such a list has to be materialized at some point (even with lazy initialization or similar) in order to keep our dataloader stateful: we need to know and checkpoint exactly which documents have been visited, which are still to be visited, and in what order, so that we can recover a training run flawlessly and deterministically.

Solution

If we peel the onion completely, the question boils down to: how can we maintain a list that is truly stateful, supports random access, and allows easy checkpointing and recovery?
This leads us to leverage an LCG (linear congruential generator), using the statefulness of the LCG to achieve the statefulness of the list.

A quick overview of the LCG we built for an arbitrary sized list:

```python
# Full-period LCG parameters (modulus m, multiplier a, increment c),
# due to ZX81, cc65, and Knuth & H. W. Lewis respectively.
LCG_PARAMS = [(2 ** 16 + 1, 75, 74), (2 ** 23, 65793, 4282663), (2 ** 32, 1664525, 1013904223)]

class LCG:
    def __init__(self, size, seed=42):
        self.size = size
        self.state = seed
        # Pick the smallest modulus that covers the requested size.
        for m, a, c in LCG_PARAMS:
            if size <= m:
                self.m, self.a, self.c = m, a, c
                break
        else:
            raise ValueError(f"size {size} exceeds the largest supported modulus")

    def _next(self):
        # One LCG step; the state is always the last value produced.
        self.state = (self.a * self.state + self.c) % self.m
        return self.state

    def next(self):
        # Rejection sampling: skip values outside [0, size). Over one full
        # cycle, each value below size is emitted exactly once, so repeated
        # calls enumerate a fixed permutation of range(size).
        while True:
            res = self._next()
            if res < self.size:
                return res
```

and validation:

```python
selector = LCG(1000000)
res = [selector.next() for _ in range(1000000)]
expected = list(range(1000000))
assert sorted(res) == expected
```

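To illustrate why this gives easy checkpointing and recovery, here is a minimal sketch (hypothetical, not the repo's actual checkpoint API) in which the entire traversal state is a single integer. It inlines the ZX81 parameters from the list above so it runs standalone:

```python
M, A, C = 2 ** 16 + 1, 75, 74          # ZX81 parameters from above

def lcg_next(state, size):
    # Step the LCG, skipping values outside [0, size).
    while True:
        state = (A * state + C) % M
        if state < size:
            return state

size, state = 1000, 42
first_half = []
for _ in range(500):
    state = lcg_next(state, size)
    first_half.append(state)

checkpoint = state                      # checkpoint: one integer

# ... later, recover and resume from the checkpoint ...
state = checkpoint
second_half = []
for _ in range(500):
    state = lcg_next(state, size)
    second_half.append(state)

# The two halves together form a full permutation of range(1000).
assert sorted(first_half + second_half) == list(range(1000))
```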
lchu-ibm commented 3 months ago

@nairbv @thoangtrvn @JRosenkranz

daviswer commented 3 months ago

Stateless Implementation

Although the LCG provides the desired random permutation, this approach introduces extra state to be tracked (our position in the recursively-generated permutation sequence, and/or our position in the shard file). A much cleaner implementation is to use the LCG as a stateless, randomized bijective map from a contiguous range of doc indices to a shuffled, noncontiguous range of doc indices.

We can do this by leveraging the fact that the state of the LCG above is always set to the last emitted value. Since the LCG emits every value in the desired range exactly once per cycle, with each output seeding the next, we can instead simply re-seed the LCG every time with a position-index argument, and it will hash that index to a new position, guaranteed collision-free. So at runtime we can simply iterate sequentially through the range of documents in a file shard owned by a given worker (possibly a subset of the full shard), and the LCG will provide a map to a new, shuffled, noncontiguous set of documents.
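The re-seeding trick can be sketched as follows (a standalone sketch using the ZX81 parameters from the earlier comment; the function name is illustrative, not the repo's API):

```python
M, A, C = 2 ** 16 + 1, 75, 74   # ZX81 LCG parameters

def shuffled_index(i, size):
    # Seed with the position index itself, then return the next value in
    # the LCG cycle that falls inside [0, size). Each index maps to its
    # unique successor in the cycle, so the map is a collision-free
    # bijection on range(size) -- no iteration state needs to be tracked.
    state = i
    while True:
        state = (A * state + C) % M
        if state < size:
            return state

n = 10_000
mapped = [shuffled_index(i, n) for i in range(n)]
assert sorted(mapped) == list(range(n))   # a bijection on range(n)
```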

Pros: Introduces no extra state to track, avoids materializing any long shuffled lists of position indices. Allows workers to now perform non-contiguous partitioning of shard files, in cases where files are split over multiple workers.

Cons: Produces similar shuffles across different shard files. The algorithm for finding the bijective mapping provided by the LCG for a given index is: take the length-m cycle of indices produced by the chosen modulus m (2^16+1, 2^23, or 2^32 above), find the given index, and proceed through the cycle until you land on a new index below the size threshold. This means that two shard files with the same number of documents will receive the same mapping, since they step through the same cycle in the same way. Furthermore, two shard files of sizes m1 and m2 with m2 > m1 will also share the same mapping, up to insertion of the new indices at least m1 and smaller than m2. Our LCG mapping is therefore clearly less random than the original shuffled-doc-list implementation.
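Both effects can be demonstrated with a small sketch (again using the ZX81 parameters; `cycle` is an illustrative helper, not part of the repo):

```python
M, A, C = 2 ** 16 + 1, 75, 74   # ZX81 LCG parameters

def cycle(size, seed=42):
    # Emit one full pass of the rejection-sampled LCG: `size` values,
    # i.e. the complete shuffled ordering for a shard of `size` docs.
    out, state = [], seed
    for _ in range(size):
        while True:
            state = (A * state + C) % M
            if state < size:
                break
        out.append(state)
    return out

# Same size -> identical shuffle for two different shard files:
assert cycle(5000) == cycle(5000)

# Larger size m2 -> same ordering as m1 once the new indices are filtered:
m1, m2 = 3000, 5000
assert [x for x in cycle(m2) if x < m1] == cycle(m1)
```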