xarray-contrib / xbatcher

Batch generation from xarray datasets
https://xbatcher.readthedocs.io
Apache License 2.0

initial prefetch for simple single chunked dim #161

Open ljstrnadiii opened 1 year ago

ljstrnadiii commented 1 year ago

POC Prefetch Generator:

This is a draft PR to articulate one possible approach to "prefetching" dask arrays or dask-backed xarray objects.

The goals were to simultaneously:

I also tried an approach using a Queue on the workers. It felt awkward, and I found myself reinventing features that dask already has.
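For illustration, here is a minimal sketch of the dask-based idea (the name prefetch_blocks and its signature are hypothetical, not the API in this PR): keep a bounded window of block futures in flight on the client and yield each computed block as it is consumed.

from collections import deque

import dask.array as da
from distributed import Client


def prefetch_blocks(array: da.Array, prefetch: int = 4):
    """Hypothetical sketch: yield computed blocks while keeping up to
    `prefetch` futures in flight on the cluster."""
    client = Client.current()
    window = deque()
    for block in array.blocks.ravel():
        window.append(client.compute(block))
        if len(window) > prefetch:
            # the oldest future has usually had time to finish; this only blocks if it has not
            yield window.popleft().result()
    while window:
        yield window.popleft().result()

The PrefetchBatchGenerator in this branch follows the same spirit but yields fixed-size batches (see the tf.data example further down).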

Results

Using Helm to deploy a cluster on Kubernetes with 8 workers (4 CPUs and 16 GB each, with relatively standard network configuration), I am able to see:

What Next?

No clue. I would like to investigate

note:

ljstrnadiii commented 1 year ago

In another attempt to simplify the example and profile transferring data from multiple workers to a single worker (where an ML task would iterate over batches), I have created this example:

import time

from more_itertools import chunked
import dask.array as da
import numpy as np
from distributed import Client, get_client, wait
import seaborn as sns
import pandas as pd

# 16 workers (4 cores, 16 GB each) on kubernetes; no special network configuration
# note: more_itertools must also be installed on the workers
client = Client("...")

chunk_gbps = {}
for max_gather in [.5e9, 1e9, 2e9, 4e9, 6e9]:
    for chunk in [128, 256, 512, 1028, 2048]:
        print(f"looking at chunk {chunk}, {max_gather / 1e9}")
        _ = client.restart()
        array = da.random.random((25000, 100, 100, 9), chunks=(chunk, 100, 100, 9)).persist()
        wait(array)
        if 'test' in client.datasets:
            del client.datasets['test']
        client.publish_dataset(test=array)

        # determine block batch size to control transferred data with gather
        ex = array.blocks[0]
        batch_size = max(int(np.floor(max_gather / ex.nbytes)), 1)

        def compute_bytes():
            client = get_client()
            array = client.get_dataset('test')
            blocks = list(array.blocks)
            nbytes = 0
            t0 = time.time()
            for block_batch in chunked(blocks, batch_size):
                fs = [client.compute(b) for b in block_batch]
                arrays = client.gather(fs)
                for array in arrays:
                    nbytes += array.nbytes
            elapsed = time.time() - t0
            return (nbytes / elapsed) / 1e9

        f = client.submit(compute_bytes, pure=False)
        chunk_gbps[(max_gather / 1e9, chunk, batch_size)] = f.result()

# plot for some trends
data = [(*k, v) for k, v in chunk_gbps.items()]
df = pd.DataFrame(data, columns=['gb_gather', 'chunk_size', 'actual_batch', 'gbps'])
sns.lineplot(x="gb_gather", y="gbps", hue="chunk_size", data=df)
[Screenshot: line plot of gbps vs. gb_gather, colored by chunk_size]
ljstrnadiii commented 1 year ago

@jhamman @maxrjones this is roughly the approach I am considering developing.

I think 2 GB/s should be fine, but I was able to get 8+ GB/s with https://github.com/NVlabs/tensorcom (which uses msgpack with pyzmq) using basic k8s pods and a manifest. I am trying to avoid that and stick with dask mechanics, but I am tempted to mock up a quick profiling script that uses zmq to bypass dask's transfer machinery entirely while still running inside dask tasks.
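As a rough reference for that comparison, a profiling sketch could look like the following (hypothetical and untested; the port, message count, and buffer size are arbitrary): one dask task pushes raw bytes over a zmq PUSH socket, another pulls them and reports throughput.

import time

import numpy as np
import zmq


def zmq_producer(n_msgs: int = 64, nbytes: int = 64 * 2**20, port: int = 5555):
    # push n_msgs buffers of nbytes random bytes each
    sock = zmq.Context.instance().socket(zmq.PUSH)
    sock.bind(f"tcp://*:{port}")
    payload = np.random.bytes(nbytes)
    for _ in range(n_msgs):
        sock.send(payload, copy=False)


def zmq_consumer(producer_host: str, n_msgs: int = 64, port: int = 5555) -> float:
    # pull the buffers and return observed throughput in GB/s
    sock = zmq.Context.instance().socket(zmq.PULL)
    sock.connect(f"tcp://{producer_host}:{port}")
    total = 0
    t0 = time.time()
    for _ in range(n_msgs):
        total += len(sock.recv())
    return (total / (time.time() - t0)) / 1e9

# e.g. submit zmq_producer to one worker and zmq_consumer to another via
# client.submit(..., workers=[...]) and compare against the dask gather numbers above.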

This all might not belong in xbatcher, but I wanted to put it out there to get any feedback people might have.

ljstrnadiii commented 1 year ago

Here is an example of using the prefetch generator with tf.data.Dataset

import time

import tensorflow as tf
from xbatcher.prefetch_generators import PrefetchBatchGenerator

# let array be chunked only along first dim
array = ...
batch_size = 128

def do_tf_ml():
    batch_gen = lambda : PrefetchBatchGenerator(array=array, batch_size=batch_size, prefetch=20)
    ds_counter = tf.data.Dataset.from_generator(batch_gen, output_types=tf.int32, output_shapes=(array[:batch_size].shape))
    nbytes = 0
    t0 = time.time()
    for count_batch in ds_counter.repeat().take(128):
        nbytes += count_batch.numpy().nbytes
    elapsed = time.time() - t0
    return nbytes / elapsed

# `client` is the distributed Client from the earlier profiling example
f = client.submit(do_tf_ml)
f.result() / 1e9  # GB/s
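Side note (a suggestion on my part, not something in this PR): on TensorFlow 2.4+ the output_types/output_shapes pair can be replaced with output_signature; the dtype below is an assumption and should match the array's actual dtype.

ds_counter = tf.data.Dataset.from_generator(
    batch_gen,
    output_signature=tf.TensorSpec(shape=array[:batch_size].shape, dtype=tf.float32),
)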
cmdupuis3 commented 1 year ago

Can the test at the bottom be wrapped as a function? I'm guessing it's not supposed to run for everyone.

ljstrnadiii commented 1 year ago

> Can the test at the bottom be wrapped as a function? I'm guessing it's not supposed to run for everyone.

@cmdupuis3 I am not sure I understand what you mean by "wrapped as a function". Do you mean being able to submit it to dask?

The BatchGenerator should be available on this branch if you check it out and install in editable mode.

cmdupuis3 commented 1 year ago

Actually, I think I was confused. I read if __name__ == "__main__" as the entry point rather than as a conditional entry point. Pythonisms are not my forte lol