openclimatefix / nowcasting_dataset

Prepare batches of data for training machine learning solar electricity nowcasting models
https://nowcasting-dataset.readthedocs.io/en/stable/
MIT License

Minimise time Dask spends computing graph #1

Closed JackKelly closed 3 years ago

JackKelly commented 3 years ago

Following on from openclimatefix/predict_pv_yield_2#27

JackKelly commented 3 years ago

Main process

The main process randomly shuffles the zarr_chunk_sequences and puts them into a Queue.

The worker processes consume this queue and, for each input sequence, return a long list of samples via another queue (a rough sketch of the plumbing is below).
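As a sketch only (the queue sizes, sentinel handling, and the load_samples_for_sequence helper are placeholders, not settled design), the plumbing might look something like this:

import multiprocessing
import random


def worker(chunk_sequence_queue, samples_queue, data_loaders):
    # Consume chunk sequences and emit long lists of samples.
    while True:
        zarr_chunk_sequence = chunk_sequence_queue.get()
        if zarr_chunk_sequence is None:  # sentinel: no more work for this worker
            break
        # Hypothetical helper: uses the data_loaders to build samples for this sequence.
        samples = load_samples_for_sequence(zarr_chunk_sequence, data_loaders)
        samples_queue.put(samples)


def main(zarr_chunk_sequences, data_loaders, n_workers=4):
    chunk_sequence_queue = multiprocessing.Queue()
    samples_queue = multiprocessing.Queue(maxsize=8)  # bound memory use

    # Randomly shuffle the chunk sequences and feed them to the workers.
    zarr_chunk_sequences = list(zarr_chunk_sequences)
    random.shuffle(zarr_chunk_sequences)
    for seq in zarr_chunk_sequences:
        chunk_sequence_queue.put(seq)
    for _ in range(n_workers):
        chunk_sequence_queue.put(None)  # one sentinel per worker

    workers = [
        multiprocessing.Process(
            target=worker, args=(chunk_sequence_queue, samples_queue, data_loaders)
        )
        for _ in range(n_workers)
    ]
    for w in workers:
        w.start()
    return samples_queue, workers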

The main process then maintains, say, 4 x batch_size of these long lists of samples in memory, rotating them out one by one (after a minimum residence time, and only once a replacement list of samples is available), and randomly samples from the lists in memory to create each batch. For each sequence, create a randomly ordered list of indices. When creating a batch, randomly pick batch_size lists (with no duplicates) and pop the last index off each. When a list runs out of indices (i.e. we've used all the samples from that list), swap it out for a new list. We'll need some experiments to see how quickly we should swap these out. A sketch of this batch-assembly logic follows.
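A minimal sketch of the batch assembly, assuming samples_queue is a queue of long sample lists filled by the workers; the minimum-residence-time rule is omitted here and the pool size is a placeholder:

import random


class SamplePool:
    # Keep ~4 x batch_size per-sequence sample lists in memory and draw batches from them.

    def __init__(self, samples_queue, batch_size, pool_multiple=4):
        self.samples_queue = samples_queue
        self.batch_size = batch_size
        # Each entry is [samples, randomly ordered indices not yet used].
        self.pool = [self._new_entry() for _ in range(pool_multiple * batch_size)]

    def _new_entry(self):
        samples = self.samples_queue.get()  # one long list of samples from a worker
        indices = list(range(len(samples)))
        random.shuffle(indices)
        return [samples, indices]

    def get_batch(self):
        batch = []
        # Pick batch_size distinct lists and pop the last index off each.
        for pool_index in random.sample(range(len(self.pool)), self.batch_size):
            samples, indices = self.pool[pool_index]
            batch.append(samples[indices.pop()])
            if not indices:
                # This list is exhausted: swap it out, in place, for a fresh one.
                self.pool[pool_index] = self._new_entry()
        return batch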

Worker processes

Each process gets a chunk_sequence from the queue.

For each sequence, it starts with PV (because PV is fast to load, and defines the lat, lon). The worker process finds the PV systems available for the duration of that sequence, and then randomly samples from those PV systems to create a DataFrame of sample_locations which specifies the start_datetime, end_datetime, lat, lon for each sample (sketched below). Here we can also ensure there are no duplicates (although we can worry about that later).
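A hedged sketch of building that DataFrame; the PV metadata column names (data_start, data_end, latitude, longitude) are assumptions, and de-duplication is deferred as noted above:

import numpy as np
import pandas as pd


def make_sample_locations(pv_metadata, seq_start, seq_end, n_samples, sample_duration):
    # Return a DataFrame of start_datetime, end_datetime, lat, lon per sample.
    # Keep only PV systems whose data covers the whole chunk sequence.
    available = pv_metadata[
        (pv_metadata.data_start <= seq_start) & (pv_metadata.data_end >= seq_end)
    ]
    # Randomly pick a PV system per sample (with replacement for simplicity;
    # duplicates are possible, which the issue says we can worry about later).
    chosen = available.sample(n=n_samples, replace=True)

    # Random sample start times within the sequence, leaving room for sample_duration.
    latest_start = seq_end - sample_duration
    offsets = np.random.uniform(0, (latest_start - seq_start).total_seconds(), n_samples)
    start_datetimes = seq_start + pd.to_timedelta(offsets, unit="s")

    return pd.DataFrame(
        {
            "start_datetime": start_datetimes,
            "end_datetime": start_datetimes + sample_duration,
            "lat": chosen.latitude.to_numpy(),
            "lon": chosen.longitude.to_numpy(),
        }
    )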

Using a ThreadPoolExecutor (one thread for sat data, one thread for NWPs), the worker process then loads these samples. Each DataLoader has the transforms relevant to it (e.g. the SatDataLoader has satellite transforms).

import concurrent.futures

future_data = []
with concurrent.futures.ThreadPoolExecutor(max_workers=len(data_loaders)) as executor:
    for data_loader in data_loaders:
        # Pass the callable and its argument separately; submit() runs it on a worker thread.
        future = executor.submit(data_loader.load, sample_locations)
        future_data.append(future)
    # Wait for every loader to finish and collect the results in submission order.
    data = [future.result() for future in future_data]

# Then use collections.ChainMap to create the union of the list of Sample dicts, without copying data!
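For completeness, a tiny sketch of that ChainMap step (the dict keys and values are made up for illustration):

import collections

# One Sample dict per DataLoader, for the same sample (keys are illustrative).
sat_sample = {"sat_data": "satellite array goes here"}
nwp_sample = {"nwp_data": "NWP array goes here"}
pv_sample = {"pv_yield": "PV timeseries goes here"}

# ChainMap presents the union of the dicts without copying the underlying arrays.
sample = collections.ChainMap(sat_sample, nwp_sample, pv_sample)
sample["nwp_data"]  # looked up in nwp_sample, no data copied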

DataLoader.load(sample_locations)

In the first version, try using Dask, just like we've done already (give Dask the long list of samples to load and let it figure out how to minimise duplicated effort), in the hope that Dask will be faster if it's only considering a single data input at a time. Minimise the number of processing steps: e.g. don't sel a large geographical region first and then sel a small square in a later step (maybe I should check the Dask graph to see whether Dask can optimise this out... although, even if it can, the optimisation itself costs processing time, so keeping things simple is good). Instead, just select the small square from the get-go (sketched below).
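A rough sketch of what that could look like, assuming the Zarr is opened lazily with xarray; the variable and dimension names (data, time, x, y), the use of lat/lon directly as the dataset's y/x coordinates, and the half_width parameter are all simplifying assumptions, not the project's actual schema:

import dask
import xarray as xr


class SatDataLoader:
    # Illustrative sketch: load one small square of satellite data per sample.

    def __init__(self, zarr_path, half_width):
        # Lazy open: nothing is read from disk until we compute.
        self.dataset = xr.open_zarr(zarr_path)
        self.half_width = half_width

    def load(self, sample_locations):
        # Build the whole list of lazy selections first, selecting the small
        # square and the time window in a single step per sample...
        lazy_squares = [
            self.dataset["data"].sel(
                time=slice(loc.start_datetime, loc.end_datetime),
                x=slice(loc.lon - self.half_width, loc.lon + self.half_width),
                y=slice(loc.lat - self.half_width, loc.lat + self.half_width),
            )
            for loc in sample_locations.itertuples()
        ]
        # ...then hand the full list to Dask in one go, so it can share any
        # chunk reads that overlap between samples.
        squares = dask.compute(*lazy_squares)
        return [{"sat_data": square} for square in squares]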

If that's still too slow, first load data from disk for sample_locations.start_datetime.min() to sample_locations.end_datetime.max(); then loop round each sample manually.
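A sketch of that fallback, under the same naming assumptions as the sketch above:

def load_eagerly(dataset, sample_locations, half_width):
    # One contiguous read covering every sample's time range...
    superset = dataset["data"].sel(
        time=slice(
            sample_locations.start_datetime.min(),
            sample_locations.end_datetime.max(),
        )
    ).load()
    # ...then slice each sample out of memory, with no further disk access.
    samples = []
    for loc in sample_locations.itertuples():
        square = superset.sel(
            time=slice(loc.start_datetime, loc.end_datetime),
            x=slice(loc.lon - half_width, loc.lon + half_width),
            y=slice(loc.lat - half_width, loc.lat + half_width),
        )
        samples.append({"sat_data": square})
    return samples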

If that's still too slow, then have each DataLoader be a child process of the worker process (!).

JackKelly commented 3 years ago

Changed tack. Did #16 instead.