leap-stc / data-and-compute-team

Repo to organize issues/management of the LEAP Data and Computation Team
Apache License 2.0

Understanding dask and batching (async discussion) #13

Open SammyAgrawal opened 2 months ago

SammyAgrawal commented 2 months ago

Wanted to open a thread to inquire about best practices regarding dask chunking.

[Screenshot 2024-08-09 at 3:58:25 PM]

Ok, imagine you have ingested a dataset that is over 100 GB, so it definitely does not fit into memory. You want to train an ML model on this dataset.

Are there any dask optimizations for this process?
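For context, a minimal sketch of how such a dataset would typically be opened lazily with dask-backed chunks (the path here is a placeholder, not the actual ingestion code):

import xarray as xr

# Hypothetical path; nothing is read into memory at open time --
# each variable is a lazy dask array chunked as stored on disk.
ds_out = xr.open_dataset(
    "gs://some-bucket/some-dataset.zarr",
    engine="zarr",
    chunks={},            # use the on-disk chunk layout
)
print(ds_out.chunks)      # mapping of dimension name -> chunk sizes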

Ran a simple test:

import time
import random

# ds_out is the previously ingested (dask-backed) dataset
times = []
batches = []
sizes = [1, 2, 4, 8, 16, 32, 64, 128]
for sz in sizes:
    start = random.randint(0, 5000)           # random slice start along "time"
    now = time.time()
    batch = ds_out.isel(time=slice(start, start + sz))
    batch.load()                              # pull the slice into memory
    print(batch.dims)
    times.append(time.time() - now)
    batches.append(batch)

I was surprised that batch size seemingly had no effect on load time.

[Screenshot 2024-08-09 at 4:00:47 PM]
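One way to sanity-check the flat timings is to count how many on-disk chunks each slice actually touches; if every batch size lands in the same number of chunks, the read cost is roughly constant. A rough sketch, assuming ds_out is dask-backed and chunked along "time":

import numpy as np

time_chunks = ds_out.chunks["time"]                   # tuple of chunk lengths along "time"
boundaries = np.cumsum((0,) + tuple(time_chunks))     # chunk edge indices

def n_chunks_touched(start, size):
    stop = start + size
    # a chunk [b0, b1) is touched if it overlaps [start, stop)
    return sum(1 for b0, b1 in zip(boundaries[:-1], boundaries[1:]) if b0 < stop and b1 > start)

for sz in [1, 2, 4, 8, 16, 32, 64, 128]:
    print(sz, n_chunks_touched(1000, sz))             # start=1000 is arbitrary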
SammyAgrawal commented 2 months ago

Questions top of mind:

jbusecke commented 2 months ago

A couple of comments:

This also means you are not using any parallelism (try recording the CPU usage while loading; I bet it never exceeds 100%).
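One way to get (and observe) parallel chunk reads is to start a local dask.distributed cluster and watch the dashboard while a batch loads; a sketch, with placeholder worker counts:

from dask.distributed import Client

client = Client(n_workers=4, threads_per_worker=2)    # local cluster
print(client.dashboard_link)                          # watch per-task CPU/memory here

batch = ds_out.isel(time=slice(0, 128))
batch.load()                                          # chunk reads are now spread across workers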

Finally, there might be some caching going on here, which could explain the fluctuations in load time, though these might also be random. Bottom line: you should use bigger batches! Insert jaws meme

It seems that if you iterated over the dataset and did this, eventually you would have loaded everything and the kernel would crash. Can you "unload" such that once a batch is processed, it gets garbage collected? If you overwrite the "batch" variable, will it automatically be garbage collected and the memory freed?

I think as long as you overwrite the object you are good and the old data will be garbage collected.
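A sketch of what that loop could look like; process() here is a placeholder for the actual training step:

import gc

for start in range(0, ds_out.sizes["time"], 128):
    batch = ds_out.isel(time=slice(start, start + 128)).load()
    process(batch)     # placeholder for the training step
    del batch          # drop the reference explicitly...
    gc.collect()       # ...and optionally force collection; simply rebinding
                       # `batch` on the next iteration has the same effect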

Does loading in line with existing chunk dimensions matter? I.e. does the "start" affect load times if you try to load across chunk lines?

What matters most (I think) is how many chunks you have to load initially. If you cross chunk boundaries, you will load into memory all of the chunks that you touch.
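If you want to avoid pulling extra chunks, you can snap the slice start to a chunk boundary; a sketch assuming uniform chunking along "time":

chunk_len = ds_out.chunks["time"][0]       # assumes uniform chunk size along "time"

start = 1234                               # arbitrary requested start
aligned_start = (start // chunk_len) * chunk_len
batch = ds_out.isel(time=slice(aligned_start, aligned_start + chunk_len)).load()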

If you use multiprocessing and spawn multiple processes, how does Dask handle loading across processes? Data balancing across N processes?

This might be a good read.
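One common pattern (a sketch, not something prescribed in this thread) is to wrap the xarray dataset in a PyTorch Dataset and let the DataLoader spawn worker processes, each of which reads its own slice; the variable name, batch length, and worker count below are placeholders, and this assumes the zarr/dask-backed dataset pickles cleanly to the workers:

import torch
from torch.utils.data import Dataset, DataLoader

class XarrayBatches(Dataset):
    """Serve fixed-length slices of ds_out along "time" as training samples."""
    def __init__(self, ds, batch_len=128):
        self.ds = ds
        self.batch_len = batch_len
        self.n = ds.sizes["time"] // batch_len

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        sl = slice(i * self.batch_len, (i + 1) * self.batch_len)
        # each DataLoader worker process loads its own slice independently
        arr = self.ds["some_variable"].isel(time=sl).load().values   # placeholder variable name
        return torch.as_tensor(arr)

loader = DataLoader(XarrayBatches(ds_out), batch_size=None, num_workers=4)
for sample in loader:
    ...   # training step goes here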