Use `async` functionality of dask?

Three ideas to try, in increasing level of change to the codebase:

1. Use `client.compute()` to asynchronously load data into memory. One thing I'm not sure about is that, in the existing code, `future.result()` returns a `DataInMemory` class. Not sure how to get dask to return a `DataInMemory` class? Might be possible to do `future = client.compute(DataInMemory(data=selected_data))`.
2. As above, but `persist` the data on the workers, and use something like `client.submit(xr.DataArray.isel, future, init_time=0)` to select data (but then we'll have to wait for the data of each example to be copied to the main process. But maybe that's OK, because each example is pretty tiny).
3. Don't use multiple PyTorch `DataLoader` processes. Instead rely entirely on dask to distribute work across multiple processes.
Huh. It's not necessary to use `async` or `Client(asynchronous=True)`! Life gets easier without these.

I'm not sure, but attempting to use dask from multiple processes feels like a bad idea :)
So let's try option 3 above: i.e. all my Python code will 'look' single-threaded (and PyTorch will only fire up one worker) and use dask:

- Maybe we can get rid of the `DataInMemory` class and subclasses, as Dask should be able to look after that for us. But I'm not sure how! Maybe just have `AsyncDataLoader` classes which keep the futures, and then use dask to sample from the futures?
- See this great example of using dask to create batches for a PyTorch model: https://examples.dask.org/machine-learning/torch-prediction.html
- We don't need to use `client.compute()`. Instead use `dask.compute()`, I think. And `dask.compute(*list_of_delayed)`.
I particularly like creating a batch in parallel (the current code creates batches one example at a time).
Might be worth a little experiment of not pre-loading data into memory, but instead loading directly from disk? That'll probably be too slow, though: to hit 50 batches per second, we need to create each batch in 20 ms.
Load all in-memory segments at once using `in_mem = dask.compute(*list_of_delayed)`.
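A minimal sketch of that pattern (the `load_segment` helper, the path, and the dates are all hypothetical):

```python
import dask
import xarray as xr

@dask.delayed
def load_segment(path, start, end):
    # Lazily open the Zarr store and pull one time segment into memory.
    ds = xr.open_zarr(path)
    return ds.sel(time=slice(start, end)).load()

# Hypothetical list of (start, end) segment boundaries.
segment_bounds = [("2019-01-01", "2019-01-02"), ("2019-02-01", "2019-02-02")]
list_of_delayed = [load_segment("data.zarr", s, e) for s, e in segment_bounds]

# Trigger all the loads in parallel; returns a tuple of in-memory Datasets.
in_mem = dask.compute(*list_of_delayed)
```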
Or maybe do everything in one graph:

```python
import dask

@dask.delayed
def load(data, start, end):
    # Select one contiguous time segment and keep it on the workers.
    selected_data = data.sel(time=slice(start, end))
    selected_data = selected_data.persist()
    return selected_data

@dask.delayed
def sample(data):
    # get_random_slice() returns a random time slice (defined elsewhere).
    random_slice = get_random_slice()
    return data.sel(time=random_slice)
```
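Composing those two delayed functions, building one batch might look something like this (a sketch; `temperature` is assumed to be a lazily-opened `xarray.DataArray`):

```python
# One graph: load a segment, then draw several random samples from it.
segment = load(temperature, "2019-01-01", "2019-01-07")
delayed_samples = [sample(segment) for _ in range(32)]

# Compute the whole batch in one go.
batch = dask.compute(*delayed_samples)
```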
Maybe we even create a tonne of batches like this, computing in the background with `client.submit()`, then go into a training loop using previously computed batches.
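A sketch of that background-computation idea, using `distributed.as_completed` to consume batches as they finish (`make_batch` is a hypothetical stand-in for the real batch-building code):

```python
import numpy as np
from dask.distributed import Client, as_completed

def make_batch(seed):
    # Hypothetical stand-in: build one training batch (here, just random data).
    rng = np.random.default_rng(seed)
    return rng.standard_normal((32, 128))

if __name__ == "__main__":
    client = Client()  # local cluster of worker processes

    # Kick off lots of batch computations in the background.
    futures = [client.submit(make_batch, seed) for seed in range(100)]

    # Train on each batch as soon as it is ready.
    for future in as_completed(futures):
        batch = future.result()
        # ...run one training step on `batch` here...
```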
OK, this is looking good... (it took me a little while to get my head round dask!)...
```python
import dask
import dask.distributed
import numpy as np

def rand_slice(lowest=0, highest=10, length=4):
    # Pick a random window of `length` consecutive indices.
    start = np.random.randint(low=lowest, high=highest - length)
    end = start + length
    return slice(start, end)

def get_batch(slices, temperature):
    """Define and compute a dask.delayed computation graph."""
    selected_nwps = []
    for time_slice in slices:  # don't shadow the builtin `slice`
        selection = temperature.isel(init_time=time_slice, step=0)
        selected_nwps.append(selection)
    return dask.compute(selected_nwps)

client = dask.distributed.Client()
slices = [rand_slice() for _ in range(1_000)]
temperature_rechunked = temperature.chunk({'step': 1})

# Non-blocking: the graph runs on dask's worker processes.
future = client.submit(get_batch, slices=slices, temperature=temperature_rechunked)

# Blocks until the data is ready.
data = future.result()
```
Some findings:

- Sending `get_batch` to `client.submit` and waiting for `future.result()` takes almost exactly as long as directly calling `get_batch()`. This is good: it suggests that calling `client.submit` on `get_batch` is the correct way to asynchronously send a `dask.delayed` graph to dask's worker processes.
- If we sample from a narrow range of the `temperature` dataset, then the code runs much faster than if we sample from a wide range. This suggests that the code is correctly figuring out that it only has to load data from disk once.

Making good progress in experiment 23, in the new function `get_batches_using_dask()`.
Still to do:

- `DataArray.values` (my custom `ToTensor()` calls `values`). Perhaps `DataArray.data` is the correct approach? Seems to work? UPDATE: `torch.utils.data._utils.collate.default_collate()` converts to Tensors for us! Yay!
- `zarr_chunk_sequences` (started this in the new `get_random_segments` func)
- `zarr_chunk_sequence`
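For reference, a tiny sketch of how `default_collate` handles the numpy-to-Tensor conversion (the array shapes are made up):

```python
import numpy as np
from torch.utils.data._utils.collate import default_collate

# A 'batch' is a list of samples; each sample is a dict of numpy arrays.
samples = [{"sat_data": np.zeros((12, 32, 32), dtype=np.float32)} for _ in range(4)]

batch = default_collate(samples)
print(batch["sat_data"].shape)  # torch.Size([4, 12, 32, 32]) -- now a Tensor
```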
Making progress... some findings:

- `get_batches_using_dask()` takes a little over a second. That's mostly due to the `sat_data.sel()` function. Investigating ways to speed it up:
  - `sat_data.sel(time=slice(start, end))` and `sat_data[:100]` take about 700 µs each.
  - `sat_data.data[:100]` takes 200 µs. (But it doesn't come with the xarray metadata, of course. That might not be too much of an issue, depending on the `transforms`.)
  - Selecting multiple slices in one indexing operation doesn't work (`sat_data.data[[slice(0, 100), slice(50, 150)]]` doesn't work).
  - `map()` doesn't help. (Well, `map` returns very quickly with a generator. But we still need to realise that generator.)
- But this doesn't matter too much, because we can run `get_batches_using_dask` and `dask.compute()` using `futures.ThreadPoolExecutor`, so all this stuff can run in the background while PyTorch is training the current batch on the GPU (see the sketch after this list).

And, if worst comes to worst, we can train on each sample a few times, by mixing up the batches, to give the 'data loading' thread more time to finish loading the next batch.
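A minimal sketch of that background-loading pattern, assuming a `get_batches_using_dask()` that returns a list of batches:

```python
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)

# Start loading the first set of batches in a background thread.
future = executor.submit(get_batches_using_dask)

while True:
    batches = future.result()  # wait for the batches currently loading
    # Immediately kick off loading the next set in the background...
    future = executor.submit(get_batches_using_dask)
    # ...while PyTorch trains on the batches we already have.
    for batch in batches:
        pass  # run one training step on `batch` here
```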
Some thoughts on passing PV data into the dask graph... Also, for sat data, try pre-selecting the date range for each segment before looping to create each delayed sample, and see if that speeds things up (sketch below).
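A sketch of that pre-selection idea, reusing `rand_slice()` from above (`sat_data` is assumed to be a lazily-opened `xarray.DataArray`):

```python
import dask

@dask.delayed
def random_crop(segment):
    # Hypothetical: draw one random sample from the pre-selected segment.
    return segment.isel(time=rand_slice(highest=len(segment.time)))

# Pre-select the date range once per segment, outside the sampling loop...
segment = sat_data.sel(time=slice("2019-06-01", "2019-06-02"))

# ...then create many cheap delayed samples from that one selection.
delayed_samples = [random_crop(segment) for _ in range(32)]
batch = dask.compute(*delayed_samples)
```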
Full pipeline with just satellite data kind of works (and I deleted a pleasingly large amount of code!)
Gets > 60 it/s, if we loop round each load 4 times.
BUT! Training still pauses while the main process is setting up the processing graph etc.
~~I'm now wondering about having two processes:~~

1. ~~Data loading process. Owns the dask Client and all the data objects. Loads data into two 'superbatches' in CPU memory (the 'current' one and the 'next' one).~~
2. ~~Training process. Shares memory (the two superbatches) with the data loading process. Sends a message to the data loading process to say "start replacing superbatch 1 or 2". (But can we share dicts? I guess we need to share the numpy arrays? Maybe try not sharing memory, and instead pickling the dicts?!? Might be slow though.) Pin the entire superbatch at once.~~
UPDATE: Use PyTorch to spin up the worker process by setting `num_workers=1`.
Ways to run data loading in a separate process:

1. `num_workers=1`. The worker process instantiates the dask Client & Zarr files etc. Probably needs a high `prefetch_factor`, because the worker process will be slow to return batches when it's loading data from disk, so the main process needs to cache lots of batches.
2. `num_workers=1`. Is it possible to throw a complete 'superbatch' over the wall to DataLoader, and then have DataLoader sample from the superbatch? I'm not sure it is (without writing our own training loop to run on the GPU).
3. `num_workers=0`. The DataSet spawns a new process every time it wants to load a new superbatch (this means we don't have to have complicated logic for killing the process). No shared memory; just copy the data between processes using a queue of length 1, or something like that.
4. `num_workers=0`. The DataSet spawns a new process which runs continually and pushes new superbatches into a Queue. The main process then just has to serve samples from that superbatch. Need to think through how to send a message to cleanly kill the worker process, so it shuts down the dask Client and Zarr stores (a sketch of this option is below).

TODO: Sort `nwp.init_time`! Update: DONE!
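A minimal sketch of option 4, using an `Event` to shut the worker down cleanly (`load_superbatch()` is a hypothetical stand-in for the dask-based loading code):

```python
import multiprocessing as mp
import queue

import numpy as np

def load_superbatch():
    # Hypothetical stand-in: the real code would build this with dask + Zarr.
    return np.random.rand(256, 12, 32, 32).astype(np.float32)

def worker(q, stop_event):
    # In the real code, this process would own the dask Client & Zarr stores.
    while not stop_event.is_set():
        superbatch = load_superbatch()
        while not stop_event.is_set():
            try:
                q.put(superbatch, timeout=1)
                break
            except queue.Full:
                continue  # the main process isn't ready for it yet
    # Clean-up (closing the dask Client, Zarr stores, etc.) would go here.

if __name__ == "__main__":
    q = mp.Queue(maxsize=1)  # at most one 'next' superbatch in flight
    stop_event = mp.Event()
    proc = mp.Process(target=worker, args=(q, stop_event))
    proc.start()

    for _ in range(10):
        superbatch = q.get()  # blocks until the next superbatch is ready
        # ...serve lots of samples from `superbatch` to the training loop...

    stop_event.set()  # ask the worker to shut down cleanly
    try:
        while True:  # drain the queue so the worker's feeder thread can exit
            q.get(timeout=1)
    except queue.Empty:
        pass
    proc.join()
```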
Looking good! Using the simplest approach (suggested two comments above). Getting >80 it/s with NWPs and >130 it/s without NWPs (after 35,000 iterations); 90 it/s with NWPs after 1 million iterations. (Probably need to re-create NWP Zarr so it's faster to load #26)
It pauses for a few seconds whilst loading more data, though, which I think is because dask swamps the CPUs whilst loading.
Try `persistent_workers=True`. Conclusion: this partially works! When one worker throws a `StopIteration`, PyTorch stops caring about that worker process, and only pulls from the remaining worker processes (which is good). BUT! When the next 'epoch' starts, PyTorch appears to not start training until all the workers are ready.

OK, the separate process is looking good. Getting 100 it/s with NWPs, although that requires looping round each super-batch about 13 times.
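For reference, a minimal sketch of the `persistent_workers` and `prefetch_factor` DataLoader settings discussed above (`dataset` is a hypothetical PyTorch Dataset):

```python
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    num_workers=4,            # worker processes survive across epochs...
    persistent_workers=True,  # ...instead of being torn down and restarted
    prefetch_factor=8,        # each worker loads 8 batches in advance
)
```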
It's possible that using a `Pipe` is faster than a `Queue`. But, of course, if we use a Queue then we could have multiple worker processes loading data, which might be nice because a single worker process rarely hits more than 200 MB/s, and I've seen multiple processes hit 800 MB/s.
Queue seems to be a lot slower than Pipe. That's confirmed by this thread: https://stackoverflow.com/questions/8463008/multiprocessing-pipe-vs-queue
`multiprocessing.Manager().Queue()` is supposed to be faster than `multiprocessing.Queue`, but it blows up:
```
multiprocessing.managers.RemoteError:
---------------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jack/miniconda3/envs/predict_pv_yield/lib/python3.8/multiprocessing/managers.py", line 243, in serve_client
    request = recv()
  File "/home/jack/miniconda3/envs/predict_pv_yield/lib/python3.8/multiprocessing/connection.py", line 251, in recv
    return _ForkingPickler.loads(buf.getbuffer())
  File "/home/jack/miniconda3/envs/predict_pv_yield/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 282, in rebuild_storage_fd
    fd = df.detach()
  File "/home/jack/miniconda3/envs/predict_pv_yield/lib/python3.8/multiprocessing/resource_sharer.py", line 58, in detach
    return reduction.recv_handle(conn)
  File "/home/jack/miniconda3/envs/predict_pv_yield/lib/python3.8/multiprocessing/reduction.py", line 189, in recv_handle
    return recvfds(s, 1)[0]
  File "/home/jack/miniconda3/envs/predict_pv_yield/lib/python3.8/multiprocessing/reduction.py", line 164, in recvfds
    raise RuntimeError('received %d items of ancdata' %
RuntimeError: received 0 items of ancdata
```
So, I think the `Pipe` version is best, and it isn't too much trouble to use a Lock to make sure only one process writes at once!
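A sketch of multiple producer processes sharing one `Pipe`, with a `Lock` guarding writes (the batch contents are made up):

```python
import multiprocessing as mp

import numpy as np

def producer(conn, lock, worker_id):
    for i in range(5):
        batch = {"worker": worker_id, "data": np.random.rand(32, 12, 32, 32)}
        with lock:  # only one process may write to the pipe at a time
            conn.send(batch)

if __name__ == "__main__":
    recv_conn, send_conn = mp.Pipe(duplex=False)
    lock = mp.Lock()

    procs = [mp.Process(target=producer, args=(send_conn, lock, i))
             for i in range(4)]
    for p in procs:
        p.start()

    # The main process reads batches off the single pipe.
    for _ in range(4 * 5):
        batch = recv_conn.recv()
        # ...train on `batch` here...

    for p in procs:
        p.join()
```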
So, I think we might finally be there!

- Using `num_workers=0` for PyTorch.
- And using our own multiple processes to load data. Each process uses Dask, and puts data into a Pipe.
- If a `Pipe()` turns out to be a bit weird because it doesn't queue data (obvs!), then we could read the Pipe's data into a list in the main process?

Next: clean up code & get PV data in there (maybe by sharing a numpy array between worker processes???)
Implemented random shuffling of samples in memory, whilst waiting for data to load from disk. Surprisingly (in a good way), this has sped up training to about 109 it/s!
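A sketch of that in-memory shuffling, assuming the current superbatch is a list of samples:

```python
import random

def samples_from(superbatch):
    # Re-serve randomly ordered samples from the in-memory superbatch;
    # the caller swaps in a new superbatch when one arrives from the loader.
    while True:
        random.shuffle(superbatch)
        yield from superbatch
```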
TODO: If one process dies, then kill all the others
~~TODO: Implement our own Queue with Pipe, List and Lock.~~ UPDATE: Actually, using `multiprocessing.Pipe` is fine!
So: this does work. And, on some metrics, it performs well (102 it/s; only 34 GB of RAM).
But a central problem is that Dask actually takes a really long time (over a minute?) to create its computation graph.
So I'm going to try an alternative approach, which I'll document in a new issue.