Open ljstrnadiii opened 1 year ago
In another attempt to simplify an example and profile transferring data from multiple workers to a single worker (where an ML task would iterate over batches), I have created this example:
import time

import dask.array as da
import numpy as np
import pandas as pd
import seaborn as sns
from distributed import Client, get_client, wait
from more_itertools import chunked

# 16 workers (4 cores, 16 GB each) on kubernetes; no special network configuration
client = Client("...")

chunk_gbps = {}
for max_gather in [.5e9, 1e9, 2e9, 4e9, 6e9]:
    for chunk in [128, 256, 512, 1028, 2048]:
        print(f"looking at chunk {chunk}, {max_gather / 1e9}")
        _ = client.restart()
        array = da.random.random((25000, 100, 100, 9), chunks=(chunk, 100, 100, 9)).persist()
        wait(array)

        # drop any dataset left over from the previous iteration, then publish
        # the new array so the task below can look it up by name
        # (the publish step is assumed; the snippet as posted only deleted the old entry)
        if "test" in client.list_datasets():
            client.unpublish_dataset("test")
        client.publish_dataset(test=array)

        # determine block batch size to control transferred data with gather
        ex = array.blocks[0]
        batch_size = max(int(np.floor(max_gather / ex.nbytes)), 1)

        def compute_bytes():
            client = get_client()
            array = client.get_dataset('test')
            blocks = list(array.blocks)
            nbytes = 0
            t0 = time.time()
            for block_batch in chunked(blocks, batch_size):
                fs = [client.compute(b) for b in block_batch]
                arrays = client.gather(fs)
                for arr in arrays:
                    nbytes += arr.nbytes
            elapsed = time.time() - t0
            # bytes per second divided by 1e9, i.e. gigabytes per second
            return (nbytes / elapsed) / 1e9

        # blocks = client.submit(get_blocks, pure=False)
        f = client.submit(compute_bytes, pure=False)
        chunk_gbps[(max_gather / 1e9, chunk, batch_size)] = f.result()

# plot for some trends
data = [(*k, v) for k, v in chunk_gbps.items()]
df = pd.DataFrame(data, columns=['gb_gather', 'chunk_size', 'actual_batch', 'gbps'])
sns.lineplot(x="gb_gather", y="gbps", hue="chunk_size", data=df)
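To make the `batch_size` arithmetic in the benchmark easier to follow, here is a small standalone sketch of the same calculation without a cluster. The helper name `blocks_per_gather` and the default `itemsize=8` (float64, which is what `da.random.random` produces) are assumptions for illustration, not part of the example above.

```python
import math


def blocks_per_gather(max_gather_bytes, block_shape, itemsize=8):
    """Return (block_nbytes, batch_size) for one gather call.

    Hypothetical helper that mirrors the benchmark's
    max(int(floor(max_gather / block.nbytes)), 1).
    """
    block_nbytes = math.prod(block_shape) * itemsize
    batch_size = max(int(max_gather_bytes // block_nbytes), 1)
    return block_nbytes, batch_size


# Compare the requested gather budget with what is actually transferred.
for chunk in [128, 256, 512, 1028, 2048]:
    nbytes, n = blocks_per_gather(0.5e9, (chunk, 100, 100, 9))
    print(
        f"chunk={chunk}: block={nbytes / 1e9:.3f} GB, "
        f"blocks per gather={n}, gathered={n * nbytes / 1e9:.3f} GB"
    )
```

Because `batch_size` is clamped to at least 1, a single block that is larger than the budget is still gathered whole. With chunks of 2048 and a 0.5 GB budget, each gather moves about 1.47 GB, so the `gb_gather` axis in the plot understates the real transfer size for the larger chunk sizes.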
@jhamman @maxrjones this is sort of the approach I am considering developing.
I think 2gbps should be fine, but I was able to get 8+gbps using basic k8s pods and a manifest, with msgpack over pyzmq. I am trying to avoid that and stick with the dask mechanics. Still, I am tempted to mock up a quick profiling script that uses zmq inside dask tasks to bypass dask's data transfer entirely.
This all might not belong in xbatcher, but I wanted to put it out there to get any feedback people might have.
Here is an example of using the prefetch generator with TensorFlow:
import time

import tensorflow as tf
from xbatcher.prefetch_generators import PrefetchBatchGenerator

# let array be chunked only along first dim
array = ...
batch_size = ...

def do_tf_ml():
    batch_gen = lambda: PrefetchBatchGenerator(array=array, batch_size=batch_size, prefetch=20)
    # the start of this call was lost in the posted snippet;
    # tf.data.Dataset.from_generator is assumed from the remaining arguments
    ds_counter = tf.data.Dataset.from_generator(
        batch_gen,
        output_types=tf.int32,
        output_shapes=(array[:batch_size].shape),
    )
    nbytes = 0
    t0 = time.time()
    for count_batch in ds_counter.repeat().take(128):
        nbytes += count_batch.numpy().nbytes
    elapsed = time.time() - t0
    return nbytes / elapsed

f = client.submit(do_tf_ml)
f.result() / 1e9
Can the test at the bottom be wrapped as a function? I'm guessing it's not supposed to run for everyone.
@cmdupuis3 I am not sure I understand what you mean by "wrapped as a function". Do you mean being able to submit it to dask?
The BatchGenerator should be available on this branch if you check it out and install in editable mode.
Actually I think I was confused. I read if __name__ == "__main__" as the entry point rather than as a conditional entry point. Pythonisms are not my forte lol
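For anyone else reading along, a minimal sketch of the conditional entry point being discussed. The function name `main` and its return value are made up for illustration and are not taken from the PR.

```python
def main():
    # Stand-in for whatever the script does when run directly.
    return "running as a script"


# This guard is true only when the file is executed directly
# (e.g. `python some_file.py`), not when it is imported as a module.
if __name__ == "__main__":
    print(main())
```

Importing the module defines `main` but does not call it, so a test at the bottom of a file guarded this way does not run for everyone who imports the package.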
POC Prefetch Generator:
This is a draft PR to articulate one possible approach to "prefetching" dask arrays or xarray arrays backed by dask.
The goals were to simultaneously:
I also tried one approach using a Queue on the workers. This felt weird, and I found myself reinventing features that dask already has.
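To illustrate the general idea of prefetching with a bounded queue, here is a minimal single-process sketch using only the standard library. It is not the implementation in this PR and involves no dask; the name `prefetch` and the `depth` parameter are assumptions for illustration.

```python
import queue
import threading

_DONE = object()  # sentinel marking the end of the stream


def prefetch(iterable, depth=2):
    """Yield items from `iterable` while a background thread loads
    up to `depth` items ahead of the consumer."""
    buf = queue.Queue(maxsize=depth)

    def producer():
        try:
            for item in iterable:
                buf.put(item)  # blocks once `depth` items are waiting
        finally:
            # Always signal the end, even if the iterable raised,
            # so the consumer does not block forever.
            buf.put(_DONE)

    threading.Thread(target=producer, daemon=True).start()

    while True:
        item = buf.get()
        if item is _DONE:
            return
        yield item


# Usage: wrap any batch iterator; the squares stand in for loaded batches.
batches = (i * i for i in range(4))
for batch in prefetch(batches, depth=2):
    print(batch)
```

The bounded queue gives backpressure: the producer stalls when the consumer falls behind, so at most `depth` batches are held in memory. The sketch leaves out error propagation and early-exit cleanup, which is part of why building this by hand on workers ends up duplicating what dask's scheduler and futures already provide.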
Using helm to deploy a cluster on kubernetes with 8 workers (4 CPUs and 16 GB each, with a relatively standard network configuration), I am able to see:
What Next?
No clue. I would like to investigate