chanzuckerberg / cellxgene-census

CZ CELLxGENE Discover Census
https://chanzuckerberg.github.io/cellxgene-census/
MIT License

Consider implementing a data-distribution-sensitive shuffling algorithm for dataloader libraries #1146

Open prathapsridharan opened 1 month ago

prathapsridharan commented 1 month ago

There is potentially an opportunity to implement a shuffling algorithm that closely approximates a uniform random shuffle of individual samples. The following recapitulates a Slack discussion.

The high level description of the algorithm is:

  1. Divide the entire data corpus into chunks of chunk_size contiguous rows
  2. Randomly select N chunks
  3. Concatenate the N chunks in memory
  4. Shuffle the newly concatenated chunk

Let's call this method scatter_gather_shuffle. It attempts to strike a balance between randomness and good I/O performance: it reads contiguous chunks rather than individual data points, because many single-point reads would be much less efficient in terms of I/O.
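The four steps above can be sketched in Python with NumPy as follows. This is a minimal illustration under stated assumptions, not the census implementation: it assumes the corpus is addressed by contiguous integer row indices `0..num_rows-1`, and the function name `scatter_gather_shuffle` and its parameters are hypothetical. It yields shuffled row indices; in a real loader, each selected chunk would be one contiguous read from storage.

```python
import numpy as np


def scatter_gather_shuffle(num_rows, chunk_size, chunks_per_batch, seed=None):
    """Hypothetical sketch: yield shuffled row indices, one in-memory batch at a time.

    Assumes rows are addressed by contiguous integers 0..num_rows-1.
    """
    rng = np.random.default_rng(seed)
    # 1. Divide the corpus into chunks of chunk_size contiguous rows (the last one may be shorter).
    starts = np.arange(0, num_rows, chunk_size)
    # Visit every chunk exactly once per epoch, in a random order.
    chunk_order = rng.permutation(len(starts))
    for i in range(0, len(chunk_order), chunks_per_batch):
        # 2. Randomly select N chunks (the next N chunks of the random permutation).
        selected = starts[chunk_order[i:i + chunks_per_batch]]
        # 3. Concatenate the N chunks in memory; each chunk is one contiguous row range.
        batch = np.concatenate([np.arange(s, min(s + chunk_size, num_rows)) for s in selected])
        # 4. Shuffle the concatenated rows in place.
        rng.shuffle(batch)
        yield batch


batches = list(scatter_gather_shuffle(num_rows=1_000, chunk_size=64, chunks_per_batch=4, seed=0))
```

Selecting chunks through a permutation (sampling without replacement) means every row is visited exactly once per epoch; the randomness inside a batch is limited to the N chunks that share it, which is the trade-off the method makes for I/O efficiency.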

The following algorithm is stated so that, if accepted, it could be packaged in a separate tensor library (like PyTorch) or a dataloader library and thus be useful to many types of ML training workloads.

The following algorithm assumes that the data is uniformly distributed across buckets. However, the algorithm could instead take a real distribution (a probability mass function) or an analytic distribution (e.g. an exponential or Poisson distribution) and compute the expectations from that input distribution. Also, the algorithm is described in the context of cellxgene_census, where the data is bucketed by dataset_id, but it applies to any dataset whose records naturally fall into buckets.

The central problem is to determine how many random chunks to gather across the data, and how large each chunk should be, so that concatenating and shuffling them yields a sequence of data points that is satisfactorily random.

"Satisfactorily random" is something the user must define. One definition that is simple to encode and generally useful applies when the user knows how data points are bucketed: a "satisfactorily random" sequence of K data points should represent some desired fraction of the buckets. More concretely, the census has 60 million observations (data points) distributed across 567 datasets (the buckets). If K observations are drawn at random from the entire corpus, what is the expected number of datasets covered by these K points? This can be worked out analytically:

  1. Let D be the number of datasets and K be the number of random observations drawn from the data corpus
  2. Let X[i] be a random variable such that X[i] = 1 if dataset-i ends up with at least 1 observation and X[i] = 0 otherwise.
  3. Pr(X[i] = 0) = ((D - 1)/D)^K - the probability that dataset-i contains no observations after K random observations are drawn (assuming each draw lands in any given dataset with probability 1/D, independently of the other draws)
  4. Pr(X[i] = 1) = 1 - ((D - 1)/D)^K - the probability that dataset-i contains at least 1 observation after K random observations are drawn
  5. Y = sum([X[i] for i in range(D)]) - the number of datasets with at least 1 observation
  6. By linearity of expectation, and since E[X[i]] = Pr(X[i] = 1), the expected number of datasets with at least 1 observation is E[Y] = sum([E[X[i]] for i in range(D)]) = D * (1 - ((D-1)/D)^K)
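As a sanity check on steps 1 to 6, the closed form can be compared with exact enumeration for tiny cases. This sketch assumes, as the derivation does, that each of the K draws lands in one of D buckets independently and uniformly; `expected_buckets_uniform` and `expected_buckets_bruteforce` are illustrative names.

```python
import itertools
import math


def expected_buckets_uniform(D: int, K: int) -> float:
    """E[Y] = D * (1 - ((D - 1) / D) ** K) for K uniform, independent draws over D buckets."""
    return D * (1 - ((D - 1) / D) ** K)


def expected_buckets_bruteforce(D: int, K: int) -> float:
    """Average number of distinct buckets over all D**K equally likely outcomes."""
    outcomes = itertools.product(range(D), repeat=K)
    total = sum(len(set(outcome)) for outcome in outcomes)
    return total / D**K


# Tiny cases where full enumeration is cheap: the closed form matches exactly.
for D, K in [(3, 2), (4, 3), (5, 4)]:
    assert math.isclose(expected_buckets_uniform(D, K), expected_buckets_bruteforce(D, K))
```

For D = 3 and K = 2, three of the nine outcomes hit a single bucket and six hit two, giving (3 + 12) / 9 = 15/9, which equals 3 * (1 - (2/3)^2).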

Thus for K = 500, the expected number of datasets represented is E[Y] = 567 * (1 - (566/567)**500) ≈ 332. For K = 2000, E[Y] ≈ 550 (almost all datasets); this is how I arrived at 2000 random chunks. Since we want good I/O efficiency, dividing the memory budget (specified in number of rows) by the number of chunks gives the chunk size: 128_000 / 2000 = 64.
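The arithmetic in the paragraph above can be reproduced directly, under the same uniform-bucket assumption, with D = 567 datasets and a memory budget of 128_000 rows:

```python
D = 567                      # number of datasets (buckets)
memory_budget_rows = 128_000


def expected_datasets(K: int) -> float:
    # Expected number of datasets hit by K uniform, independent random draws.
    return D * (1 - ((D - 1) / D) ** K)


for K in (500, 2000):
    print(K, round(expected_datasets(K)))  # prints "500 332", then "2000 550"

num_chunks = 2000
chunk_size = memory_budget_rows // num_chunks
print(chunk_size)  # 64
```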

Pseudocode for the algorithm:

import math
from typing import Any, List


def heuristic_shuffle(
        data_records_index_keys: List[Any],
        num_buckets: int,
        min_bucket_fraction: float = 0.6,
        max_records_to_load: int = -1):
    """
    Args:
        data_records_index_keys:
            An ordering of data record keys such that reading records in this order is
            very fast from the underlying storage medium. Typically this is a list of primary keys,
            but any fast indexing scheme will work.

        num_buckets:
            The variable with respect to which "satisfactorily random" is defined. Typically,
            non-random data are bunched together into buckets, and a "satisfactorily random" sequence
            will represent these buckets well; see `min_bucket_fraction`.

        min_bucket_fraction:
            The minimum fraction of buckets to be represented in the random sequence of data points
            loaded in memory; see `max_records_to_load`. Must be strictly between 0 and 1.

        max_records_to_load:
            The maximum number of random records to hold in memory. Specifying -1 means all records
            are loaded in memory and shuffled. Typically, this algorithm is used in memory-constrained
            environments, so a typical value will be in the tens or hundreds of thousands (ex: 128_000).
    """
    if not 0 < min_bucket_fraction < 1:
        raise ValueError("min_bucket_fraction must be strictly between 0 and 1")
    if max_records_to_load == -1:
        max_records_to_load = len(data_records_index_keys)

    def expected_buckets(num_chunks: int) -> float:
        # E[Y] = D * (1 - ((D - 1) / D) ** K), assuming records are uniformly spread across buckets
        return num_buckets * (1 - ((num_buckets - 1) / num_buckets) ** num_chunks)

    # Calculate the number of random chunks to read: grow num_chunks until the expected bucket
    # coverage reaches the target. Testing equality with math.isclose would almost never succeed
    # once num_chunks is doubled, so the loop compares with < instead.
    num_chunks = 1
    min_buckets_desired = min_bucket_fraction * num_buckets
    while expected_buckets(num_chunks) < min_buckets_desired:
        num_chunks *= 2  # could use a different factor to scale num_chunks if desired

    # Each chunk gets an equal share of the memory budget; keep at least one row per chunk.
    chunk_size = max(1, max_records_to_load // num_chunks)

    # do_scatter_gather_shuffle (not defined here) performs the chunked read / concatenate / shuffle steps.
    return do_scatter_gather_shuffle(data_records_index_keys, shuffle_chunk_count=num_chunks, soma_chunk_size=chunk_size)
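As an alternative to doubling `num_chunks`, which can overshoot the target (for D = 567 and a 0.6 target it jumps from 512, still below target, to 1024), the uniform-bucket formula can be inverted in closed form: the smallest K with 1 - ((D - 1)/D)^K >= f is K = ceil(ln(1 - f) / ln((D - 1)/D)). This is a sketch under the same uniform assumption; `min_chunks_for_coverage` is a hypothetical helper, and its final loops only guard against floating-point rounding in the ceil.

```python
import math


def min_chunks_for_coverage(num_buckets: int, min_bucket_fraction: float) -> int:
    """Smallest K whose expected bucket coverage (uniform buckets) is >= the target fraction."""
    if num_buckets < 2 or not 0 < min_bucket_fraction < 1:
        raise ValueError("need num_buckets >= 2 and 0 < min_bucket_fraction < 1")
    q = (num_buckets - 1) / num_buckets  # probability that one draw misses a given bucket

    def coverage(k: int) -> float:
        # Expected fraction of buckets hit by k draws: E[Y] / D.
        return 1 - q ** k

    k = math.ceil(math.log(1 - min_bucket_fraction) / math.log(q))
    # Nudge k so that it is exactly the smallest value meeting the target.
    while k > 1 and coverage(k - 1) >= min_bucket_fraction:
        k -= 1
    while coverage(k) < min_bucket_fraction:
        k += 1
    return k


num_chunks = min_chunks_for_coverage(567, 0.6)  # 520
chunk_size = 128_000 // num_chunks              # 246
```

With a 128_000-row budget, doubling lands on 1024 chunks of 125 rows, while the closed form meets the same 60% target with 520 chunks of 246 rows, i.e. fewer and larger reads.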
prathapsridharan commented 1 month ago

A specific comment from the Slack thread above, recapitulated here because it is useful:

K is the number of data_points. A data_point can be viewed either as a single cell or as a chunk of contiguous rows. The chunk-of-contiguous-rows view is used here for I/O efficiency, because a read query is more efficient when it fetches more data per query.

The thinking is that picking K random sections of the entire corpus gives us a certain desired coverage of datasets. With the chunk (contiguous rows) concept, you can think of the entire census as split evenly into sections/chunks, each of which is then a data_point.

Even when you don't actively do any chunking, you can view the entire census as a collection of sections/chunks where each chunk is 1 row long. But that would be inefficient in terms of I/O when reading from the storage medium. So make each chunk longer than 1 row (for I/O efficiency) and determine the chunk size from the memory budget, i.e. the number of rows that can actually be loaded in memory: chunk_size = memory_budget_num_rows // K.
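The "chunk as a data point" view can be checked with a small simulation. The sketch assumes a synthetic corpus in which D equal-sized buckets are stored contiguously (the situation in which chunked reads are correlated with buckets) and counts how many buckets K randomly chosen chunks touch; the function name and all sizes are illustrative, not census values. Because a loader samples chunks without replacement, and because a chunk may straddle a bucket boundary, the simulated coverage can come out slightly above the with-replacement formula.

```python
import numpy as np


def simulated_bucket_coverage(num_buckets, rows_per_bucket, chunk_size, K, trials=200, seed=0):
    """Average number of distinct buckets touched by K random chunks (sampled without replacement)."""
    rng = np.random.default_rng(seed)
    num_rows = num_buckets * rows_per_bucket
    # Rows are stored sorted by bucket: row r belongs to bucket r // rows_per_bucket.
    bucket_of_row = np.repeat(np.arange(num_buckets), rows_per_bucket)
    starts = np.arange(0, num_rows, chunk_size)
    covered = []
    for _ in range(trials):
        chosen = rng.choice(starts, size=K, replace=False)
        buckets = set()
        for s in chosen:
            buckets.update(bucket_of_row[s:s + chunk_size].tolist())
        covered.append(len(buckets))
    return float(np.mean(covered))


sim = simulated_bucket_coverage(num_buckets=100, rows_per_bucket=640, chunk_size=64, K=100)
formula = 100 * (1 - (99 / 100) ** 100)  # uniform, with-replacement estimate
```

Here 640 rows per bucket and 64-row chunks align with bucket boundaries, so each chunk lies in exactly one bucket and the remaining gap from the formula comes from sampling chunks without replacement.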

prathapsridharan commented 1 month ago

A calculation of the expected number of datasets represented when K random data points are selected from the census, using the real distribution of data points (observations) across dataset_ids. This was calculated by @pablo-gar:

K       expected_n_datasets
500     197.9385
1000    285.0254
1500    342.0426
2000    383.4680
2500    415.3354
3000    440.7713
3500    461.6108
4000    479.0251
4500    493.8070
5000    506.5175
5500    517.5673
6000    527.2646
6500    535.8462
7000    543.4966
7500    550.3621
8000    556.5600
8500    562.1854
9000    567.3163
9500    572.0169
10000   576.3408
10500   580.3331
11000   584.0316
11500   587.4689
12000   590.6724
12500   593.6661
13000   596.4705
13500   599.1037
14000   601.5811
14500   603.9167
15000   606.1226
15500   608.2095
16000   610.1870
16500   612.0636
17000   613.8470
17500   615.5440
18000   617.1608
18500   618.7030
19000   620.1757
19500   621.5835
20000   622.9306
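Numbers of this kind can be computed from per-dataset observation counts by generalizing step 6 of the derivation to a non-uniform distribution: if each draw lands in dataset i with probability p_i = n_i / N, then E[Y] = sum_i (1 - (1 - p_i)^K). This sketch assumes independent draws with replacement; the exact method behind the table above is not stated, and the counts below are made up for illustration, not census data.

```python
import numpy as np


def expected_buckets_from_counts(counts, K):
    """E[Y] = sum_i (1 - (1 - p_i) ** K) with p_i = counts[i] / sum(counts).

    Assumes K independent draws with replacement, each landing in bucket i with probability p_i.
    """
    counts = np.asarray(counts, dtype=float)
    p = counts / counts.sum()
    return float(np.sum(1.0 - (1.0 - p) ** K))


# Illustrative (made-up) counts: a few large datasets and many small ones, versus equal sizes.
skewed = [100_000] * 10 + [1_000] * 90
uniform = [10_000] * 100
for K in (500, 2000):
    print(K, expected_buckets_from_counts(uniform, K), expected_buckets_from_counts(skewed, K))
```

With equal counts this reduces to the uniform formula D * (1 - ((D - 1)/D)^K); a skewed distribution gives fewer expected datasets for the same K than equal counts would, since each term 1 - (1 - p)^K is concave in p.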
prathapsridharan commented 1 month ago

Here is a histogram of the distribution of observations across dataset_ids in the census (computed by @pablo-gar):

[Figure: histogram of the distribution of observations across dataset_ids in the census]
prathapsridharan commented 1 month ago

The histogram in the previous comment shows up on ZenHub but not on this GitHub issue, so it is reattached in this comment as well:

[Figure: histogram of observations per dataset_id (reattached copy of the figure above)]