openclimatefix / predict_pv_yield

Using optical flow & machine learning to predict PV yield
MIT License

Use dask to load data for many batches asynchronously #27

Closed: JackKelly closed this issue 3 years ago

JackKelly commented 3 years ago

Use the async functionality of dask?

Three ideas to try, in increasing level of change to the codebase (the first two are sketched in code after the list):

  1. Use client.compute() to asynchronously load data into memory. One thing I'm not sure about: in the existing code, future.result() returns a DataInMemory object, and I'm not sure how to get dask to return a DataInMemory. It might be possible to do future = client.compute(DataInMemory(data=selected_data)).

  2. As above, but persist the data on the workers, and use something like client.submit(xr.DataArray.isel, future, init_time=0) to select data. We'd then have to wait for each example's data to be copied to the main process, but maybe that's OK, because each example is pretty tiny.

  3. Don't use multiple PyTorch DataLoader processes. Instead rely entirely on dask to distribute work across multiple processes.
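
Options 1 and 2 might look roughly like this (a sketch only, not working repo code: selected_data is assumed to be a lazy, dask-backed xarray object, and DataInMemory is the existing class mentioned above):

import xarray as xr
from dask.distributed import Client

client = Client()

# Option 1: asynchronously load the selected data into local memory.
# client.compute() returns a Future immediately; .result() blocks until ready.
future = client.compute(selected_data)
data_in_memory = DataInMemory(data=future.result())  # one way round the DataInMemory question: wrap after computing

# Option 2: leave the computed data in worker memory and select each example
# remotely, copying only the (small) selected example back to this process.
data_future = client.compute(selected_data)
example_future = client.submit(xr.DataArray.isel, data_future, init_time=0)
example = example_future.result()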

JackKelly commented 3 years ago

Huh. It's not necessary to use async or Client(asynchronous=True)! Life gets easier without these.

JackKelly commented 3 years ago

I'm not sure, but attempting to use dask from multiple processes feels like a bad idea :)

So let's try option 3 above: all my Python code will 'look' single-threaded (and PyTorch will only fire up one worker), and dask will handle the parallelism.

JackKelly commented 3 years ago

See this great example for using dask to create batches for a PyTorch model: https://examples.dask.org/machine-learning/torch-prediction.html

We don't need client.compute(). Instead, use dask.compute(), I think, and dask.compute(*list_of_delayed) to compute a whole list of delayed objects at once.

I particularly like creating a batch in parallel (the current code creates batches one example at a time).

Might be worth a little experiment of not pre-loading data into memory, but instead loading directly from disk?

But that'll probably be too slow. To hit 50 batches per second, we need to create each batch in 20 ms.

Load all in-memory segments at once using in_mem = dask.compute(*list_of_delayed).

Or maybe do everything in one graph:

import dask

@dask.delayed
def load(data, start, end):
    """Select a contiguous time segment and keep it in (distributed) memory."""
    selected_data = data.sel(time=slice(start, end))
    selected_data = selected_data.persist()
    return selected_data

@dask.delayed
def sample(data):
    """Take a random time slice from an already-loaded segment."""
    # get_random_slice() is a placeholder for whatever picks a random time range.
    random_slice = get_random_slice()
    return data.sel(time=random_slice)
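
A hedged sketch of how those delayed functions might be wired together and computed in one go (sat_data, the dates and the sample counts are purely illustrative):

import dask
import pandas as pd

# Lazily describe the work: load a few contiguous segments, then draw several
# random samples from each one. Nothing executes until dask.compute() is called.
segment_bounds = [
    (pd.Timestamp('2019-01-01'), pd.Timestamp('2019-01-02')),
    (pd.Timestamp('2019-06-01'), pd.Timestamp('2019-06-02')),
]
delayed_segments = [load(sat_data, start, end) for start, end in segment_bounds]
delayed_samples = [sample(segment) for segment in delayed_segments for _ in range(16)]

# dask.compute(*list_of_delayed) materialises everything in parallel.
batch = dask.compute(*delayed_samples)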

Maybe we even create a tonne of batches like this, computing them in the background with client.submit(), then go into a training loop using previously computed batches.

JackKelly commented 3 years ago

OK, this is looking good... (it took me a little while to get my head round dask!)...

import numpy as np
import dask
import dask.distributed

def rand_slice(lowest=0, highest=10, length=4):
    """Pick a random contiguous slice of `length` items from [lowest, highest)."""
    start = np.random.randint(
        low=lowest,
        high=highest - length)
    end = start + length
    return slice(start, end)

def get_batch(slices, temperature):
    """Build the lazy dask selection graph for one batch, then compute it."""
    selected_nwps = []
    for selected_slice in slices:
        selection = temperature.isel(init_time=selected_slice, step=0)
        selected_nwps.append(selection)

    # dask.compute() returns a tuple containing the computed list.
    return dask.compute(selected_nwps)

# `temperature` is a dask-backed xarray DataArray of NWP data, loaded elsewhere.
client = dask.distributed.Client()
slices = [rand_slice() for _ in range(1_000)]
temperature_rechunked = temperature.chunk({'step': 1})

# Non-blocking: the batch is selected and computed in the background.
future = client.submit(get_batch, slices=slices, temperature=temperature_rechunked)

# Blocks until the data is ready.
data = future.result()

Some findings:

JackKelly commented 3 years ago

Making good progress in experiment 23, in a new function, get_batches_using_dask().

Still to do:

JackKelly commented 3 years ago

Making progress... some findings:

get_batches_using_dask() takes a little over a second. That's mostly due to the sat_data.sel() function. Investigating ways to speed it up:

But this doesn't matter too much, because we can run get_batches_using_dask() and dask.compute() in a background thread using concurrent.futures.ThreadPoolExecutor. So all of this can run in the background while PyTorch is training on the current batch on the GPU.

And, if worst comes to worst, we can train on each sample a few times, by mixing up the batches, to give the 'data loading' thread more time to finish loading the next batch.
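
For instance, a minimal sketch (get_batches_using_dask() is the function mentioned above; train_step() is a placeholder for one PyTorch optimisation step):

from concurrent import futures

executor = futures.ThreadPoolExecutor(max_workers=1)

# The very first load has to block.
batches = executor.submit(get_batches_using_dask).result()

while True:
    # Start preparing the next batches while we train on the current ones.
    next_batches_future = executor.submit(get_batches_using_dask)
    for batch in batches:
        train_step(batch)  # placeholder: trains on the GPU while loading continues
    batches = next_batches_future.result()  # hopefully ready by the time we get here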

JackKelly commented 3 years ago

Some thoughts on passing PV data into the dask graph:

Also, for sat data, try pre-selecting date ranges for each segment before looping to create each delayed sample. See if that speeds things up.

JackKelly commented 3 years ago

Full pipeline with just satellite data kind of works (and I deleted a pleasingly large amount of code!)

Gets > 60 it/s, if we loop round each load 4 times.

BUT! Training still pauses while the main process is setting up the processing graph etc.

~I'm now wondering about having two processes:~

~1. Data loading process. Owns the dask Client and all the data objects. Loads data into two 'superbatches' in CPU memory (the 'current' one and the 'next' one).~

~2. Training process. Shares memory (the two superbatches) with the data loading process. Sends a message to the data loading process to say "start replacing superbatch 1 or 2". (But can we share dicts? I guess we need to share the numpy arrays? Maybe try not sharing memory, and instead pickling the dicts?!? That might be slow, though.) Pin the entire superbatch at once.~

UPDATE: Use PyTorch to spin up the worker process by setting num_workers=1
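
One way that might look, as a sketch rather than the repo's actual implementation (get_batches_using_dask() is assumed from the earlier comments):

import torch

class DaskBatchDataset(torch.utils.data.IterableDataset):
    """Builds whole batches with dask inside the single DataLoader worker process."""

    def __iter__(self):
        # The dask Client would be created here, lazily, inside the worker process.
        while True:
            for batch in get_batches_using_dask():
                yield batch

# batch_size=None because the dataset already yields complete batches;
# num_workers=1 makes PyTorch own the single data-loading worker process.
data_loader = torch.utils.data.DataLoader(
    DaskBatchDataset(), batch_size=None, num_workers=1)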

JackKelly commented 3 years ago

Ways to run data loading in a separate process:

JackKelly commented 3 years ago

TODO: Sort nwp.init_time! Update: DONE!

JackKelly commented 3 years ago

Looking good! Using the simplest approach (suggested two comments above). Getting >80 it/s with NWPs and >130 it/s without NWPs (after 35,000 iterations); 90 it/s with NWPs after 1 million iterations. (Probably need to re-create NWP Zarr so it's faster to load #26)

Pauses for a few seconds whilst loading more data, though, which I think is because dask swamps the processes' CPU whilst loading.

Try:

JackKelly commented 3 years ago

OK, separate process is looking good. Getting 100 it/s with NWPs; although that requires looping round each super-batch about 13 times.

It's possible that using a Pipe is faster than a Queue. But, of course, if we use a Queue then we could have multiple worker processes loading data, which might be nice, because a single worker process rarely hits more than 200 MB/s, and I've seen multiple processes hit 800 MB/s.

JackKelly commented 3 years ago

Queue seems to be a lot slower than Pipe. That's confirmed by this thread: https://stackoverflow.com/questions/8463008/multiprocessing-pipe-vs-queue
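
A rough micro-benchmark along those lines (illustrative only; the payload size and message count are arbitrary, and real batches are much bigger):

import multiprocessing as mp
import time

import numpy as np

PAYLOAD = np.zeros((128, 128, 12), dtype=np.float32)
N = 1_000

def send_via_pipe(conn):
    for _ in range(N):
        conn.send(PAYLOAD)
    conn.close()

def send_via_queue(queue):
    for _ in range(N):
        queue.put(PAYLOAD)

if __name__ == '__main__':
    # Pipe: one sender process; receive everything in the main process.
    recv_conn, send_conn = mp.Pipe(duplex=False)
    sender = mp.Process(target=send_via_pipe, args=(send_conn,))
    start = time.perf_counter()
    sender.start()
    for _ in range(N):
        recv_conn.recv()
    sender.join()
    print('Pipe: ', time.perf_counter() - start, 'seconds')

    # Queue: same payloads, same message count.
    queue = mp.Queue()
    sender = mp.Process(target=send_via_queue, args=(queue,))
    start = time.perf_counter()
    sender.start()
    for _ in range(N):
        queue.get()
    sender.join()
    print('Queue:', time.perf_counter() - start, 'seconds')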

JackKelly commented 3 years ago

multiprocessing.Manager().Queue() is supposed to be faster than multiprocessing.Queue, but it blows up:

multiprocessing.managers.RemoteError: 
---------------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jack/miniconda3/envs/predict_pv_yield/lib/python3.8/multiprocessing/managers.py", line 243, in serve_client
    request = recv()
  File "/home/jack/miniconda3/envs/predict_pv_yield/lib/python3.8/multiprocessing/connection.py", line 251, in recv
    return _ForkingPickler.loads(buf.getbuffer())
  File "/home/jack/miniconda3/envs/predict_pv_yield/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 282, in rebuild_storage_fd
    fd = df.detach()
  File "/home/jack/miniconda3/envs/predict_pv_yield/lib/python3.8/multiprocessing/resource_sharer.py", line 58, in detach
    return reduction.recv_handle(conn)
  File "/home/jack/miniconda3/envs/predict_pv_yield/lib/python3.8/multiprocessing/reduction.py", line 189, in recv_handle
    return recvfds(s, 1)[0]
  File "/home/jack/miniconda3/envs/predict_pv_yield/lib/python3.8/multiprocessing/reduction.py", line 164, in recvfds
    raise RuntimeError('received %d items of ancdata' %
RuntimeError: received 0 items of ancdata

JackKelly commented 3 years ago

So, I think the Pipe version is best, and it isn't too much trouble to use a Lock to make sure only one process writes at once!

So, I think we might finally be there!

Using num_workers=0 for PyTorch.

And we use our own multiple processes to load data. Each process uses dask and puts data into a Pipe.

If a Pipe() turns out to be a bit awkward because it doesn't queue data (obviously!), then we could put the data from the Pipe into a list in the main process?
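
A hedged sketch of that arrangement (load_superbatch() stands in for the dask-based loading code, and train_step() for the PyTorch training step; everything else is just the standard multiprocessing API):

import multiprocessing as mp

def data_worker(send_conn, write_lock):
    """Runs in its own process: builds batches with dask, writes them to the Pipe."""
    # The dask Client would be created here, inside the worker process.
    while True:
        for batch in load_superbatch():  # placeholder for the dask loading code
            with write_lock:  # only one worker writes to the Pipe at a time
                send_conn.send(batch)

if __name__ == '__main__':
    recv_conn, send_conn = mp.Pipe(duplex=False)
    write_lock = mp.Lock()
    workers = [
        mp.Process(target=data_worker, args=(send_conn, write_lock), daemon=True)
        for _ in range(2)
    ]
    for worker in workers:
        worker.start()

    # The main (training) process keeps num_workers=0 in the PyTorch DataLoader
    # and reads batches straight off the Pipe (optionally buffering them in a list).
    while True:
        batch = recv_conn.recv()
        train_step(batch)  # placeholder for one PyTorch training step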

JackKelly commented 3 years ago

Next: Clean up code & get PV data in there (maybe by sharing a numpy array between worker processes???)
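
One possible way to share a numpy array between processes is the standard library's multiprocessing.shared_memory module (Python 3.8+); a sketch, with an illustrative array shape:

import numpy as np
from multiprocessing import shared_memory

# In the main process: allocate a shared block and copy the PV data into it.
pv_data = np.random.rand(10_000, 128).astype(np.float32)  # placeholder PV array
shm = shared_memory.SharedMemory(create=True, size=pv_data.nbytes)
shared_pv = np.ndarray(pv_data.shape, dtype=pv_data.dtype, buffer=shm.buf)
shared_pv[:] = pv_data[:]

# In a worker process: attach to the same block by name (no copy). The worker
# needs to be told shm.name, the shape and the dtype, e.g. via Process args.
existing = shared_memory.SharedMemory(name=shm.name)
worker_view = np.ndarray(pv_data.shape, dtype=pv_data.dtype, buffer=existing.buf)

# Clean up when finished: existing.close() in the workers; shm.close() and
# shm.unlink() in the main process.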

JackKelly commented 3 years ago

Implemented random shuffling of samples in memory, whilst waiting for data to load from disk. Surprisingly (in a good way), this has sped up training to about 109 it/s!
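
Roughly this idea, as a sketch (in_memory_samples, batch_size, train_step and the Pipe connection are placeholders for the real objects):

import random

# While the data-loading processes haven't sent new samples down the Pipe yet,
# keep re-shuffling and re-using the samples already held in memory.
while not recv_conn.poll():                      # nothing new waiting on the Pipe
    random.shuffle(in_memory_samples)            # shuffling in memory is cheap
    train_step(in_memory_samples[:batch_size])   # placeholder PyTorch training step

in_memory_samples.extend(recv_conn.recv())       # fold the newly loaded samples in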

JackKelly commented 3 years ago

TODO: If one process dies, then kill all the others

JackKelly commented 3 years ago

~TODO: Implement our own Queue with Pipe, List and Lock.~ UPDATE: Actually, using multiprocessing.Pipe is fine!

JackKelly commented 3 years ago

So: this does work. And, on some metrics, it performs well (102 it/s; only 34 GB of RAM).

But, a central problem is that Dask actually takes a really long time (over a minute?) to create its graph of computation.

So I'm going to try an alternative approach, which I'll document in a new issue.