openclimatefix / nwp

Tools for downloading and processing numerical weather predictions
MIT License

Don't pass the massive `xr.Dataset` between processes! #22

Open · JackKelly opened 1 year ago

JackKelly commented 1 year ago

Describe the bug

The code works at the moment. But it does something that's slow, CPU-intensive, and memory-intensive: it passes massive xr.Datasets between processes.

The main process does this:

    # Run the processes!
    with multiprocessing.Pool() as pool:
        for ds in pool.imap(load_grib_files, tasks):
            append_to_zarr(ds, destination_zarr_path)

This requires every `ds` to be pickled and copied from the worker process to the main process through a pipe, which is very slow (multiple seconds) for large objects.

Experiment

This minimal example takes 6 seconds to run, and uses a trivially small amount of RAM and CPU. Each worker process creates a 160 MB array. But, crucially, each worker process doesn't pass that array back to the main process:

from multiprocessing.pool import Pool
from time import sleep
import numpy as np

# task executed in a worker process
def task(identifier: int):
    # generate a 160 MB array of random values:
    arr = np.random.rand(1000, 1000, 20)
    # report a message
    print(f'Task {identifier} executing.', flush=True)
    # block for a (random) moment
    sleep(arr[0, 0, 0])

if __name__ == '__main__':
    # create and configure the process pool
    with Pool() as pool:
        # issue tasks to the process pool
        pool.imap(task, range(50))
        # shutdown the process pool
        pool.close()
        # wait for all issued tasks to complete
        pool.join()

(this code is adapted from here)

If we just append the line `return arr` at the end of the task function (so each worker process pickles the array and attempts to send it to the main process), then the script runs for 30 seconds at max CPU, and then consumes all the RAM on my desktop before crashing!
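
For reference, the modification is just this (the rest of the script is unchanged; the returned arrays have to be pickled and piped back, and since nothing consumes the iterator they presumably accumulate in the main process):

from multiprocessing.pool import Pool
from time import sleep
import numpy as np

# identical to the task above, except for the final line
def task(identifier: int):
    arr = np.random.rand(1000, 1000, 20)  # ~160 MB of float64
    print(f'Task {identifier} executing.', flush=True)
    sleep(arr[0, 0, 0])
    return arr  # the only change: the array is now pickled and sent back

if __name__ == '__main__':
    with Pool() as pool:
        pool.imap(task, range(50))
        pool.close()
        pool.join()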

Expected behavior

I think the fix is simple: we just tell each worker process to save the dataset itself. I'm as sure as I can be that `imap` will still guarantee that the processes run in order, even if the processes take different amounts of time to complete. UPDATE: I was wrong! `imap` runs tasks in arbitrary order, so the workers would finish in arbitrary order, and we can't append to the Zarr in arbitrary order.

Additional context

The code used to use a "chain of locks"... but that proved unreliable, and so the "chain of locks" was replaced with `imap` in commit 33330bfb9e40ae682023002c0267ecfeba974704. Replacing the "chain of locks" with `imap` was definitely the right thing to do (much simpler code; much more stable!). We just need to make sure we don't pass massive datasets between processes :slightly_smiling_face:.

JackKelly commented 1 year ago

After chatting with @jacobbieker... it turns out I was wrong! `imap` runs tasks in arbitrary order!

Evidence:

from multiprocessing.pool import Pool
from time import sleep

import numpy as np

# task executed in a worker process
def task(identifier: int):
    # generate a 160 MB array of random values:
    rng = np.random.default_rng(seed=identifier)
    arr = rng.random((1000, 1000, 20))
    sleep_time_secs = arr[0, 0, 0] * 4
    print(f'Task {identifier} sleeping for {sleep_time_secs:.3f} secs...', flush=True)
    sleep(sleep_time_secs)
    print(f'Task {identifier} DONE!', flush=True)

if __name__ == '__main__':
    # create and configure the process pool
    with Pool() as pool:
        # issue tasks to the process pool
        pool.imap(task, range(50))
        # shutdown the process pool
        pool.close()
        # wait for all issued tasks to complete
        pool.join()

Produces this output:

(nwp) jack@jack-NUC:~/dev/ocf/nwp/scripts$ time python test_imap.py 
Task 2 sleeping for 1.046 secs...
Task 3 sleeping for 0.343 secs...
Task 1 sleeping for 2.047 secs...
Task 5 sleeping for 3.220 secs...
Task 7 sleeping for 2.500 secs...
Task 0 sleeping for 2.548 secs...
Task 4 sleeping for 3.772 secs...
Task 6 sleeping for 2.153 secs...
Task 3 DONE!
Task 8 sleeping for 1.308 secs...
Task 2 DONE!
Task 9 sleeping for 3.481 secs...
Task 8 DONE!
Task 10 sleeping for 3.824 secs...
Task 1 DONE!
Task 11 sleeping for 0.514 secs...
Task 6 DONE!
Task 12 sleeping for 1.003 secs...
Task 7 DONE!
Task 11 DONE!
...
Task 47 sleeping for 2.967 secs...
Task 48 sleeping for 1.551 secs...
Task 46 sleeping for 3.622 secs...
Task 40 DONE!
Task 49 sleeping for 1.451 secs...
Task 43 DONE!
Task 48 DONE!
Task 42 DONE!
Task 45 DONE!
Task 41 DONE!
Task 49 DONE!
Task 47 DONE!
Task 46 DONE!

real    0m16.123s
user    0m5.739s
sys     0m3.232s

JackKelly commented 1 year ago

Two solutions spring to mind:

  1. Can we write to Zarr in arbitrary order? Maybe Zarr can do this out-of-the-box now? Or maybe we need to "lazily pre-allocate" the entire array first?
  2. Failing that, each worker process could write a NetCDF file to disk, and the main process could load that NetCDF file and write it to the Zarr (a sketch of `append_netcdf_to_zarr` follows after the snippet). Something like this:
    # Run the processes!
    with multiprocessing.Pool() as pool:
        for netcdf_filename in pool.imap(convert_grib_files_to_netcdf, tasks):
            append_netcdf_to_zarr(netcdf_filename, destination_zarr_path)
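
A minimal sketch of what that hypothetical `append_netcdf_to_zarr` helper could look like, assuming `init_time` is the append dimension (this helper isn't in the repo; it's just the idea above spelled out):

import xarray as xr

def append_netcdf_to_zarr(netcdf_filename: str, destination_zarr_path: str) -> None:
    """Load one NetCDF file written by a worker and append it to the Zarr store."""
    ds = xr.open_dataset(netcdf_filename)
    # Assumes the Zarr store already exists and that init_time is the
    # dimension along which successive datasets are concatenated.
    ds.to_zarr(destination_zarr_path, append_dim="init_time")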

JackKelly commented 1 year ago

Good: Option 1 (from the comment above) sounds viable. The xarray docs suggest that we can write to the Zarr in arbitrary order and in parallel if we first create the relevant Zarr metadata. Some relevant quotes from the xarray docs:

you can use region to write to limited regions of existing arrays in an existing Zarr store. This is a good option for writing data in parallel from independent processes. To scale this up to writing large datasets, the first step is creating an initial Zarr store without writing all of its array data. ... Concurrent writes with region are safe as long as they modify distinct chunks in the underlying Zarr arrays (or use an appropriate lock).
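
A minimal sketch of that first step, creating the store's metadata without writing any array data (shapes, chunk sizes, and the store path are made up for illustration):

import dask.array as da
import xarray as xr

NUM_INIT_TIMES = 8  # illustrative; the real store spans many more init times

# A lazy, dask-backed Dataset covering every init_time we intend to fill.
template = xr.Dataset(
    {
        "UKV": (
            ("init_time", "y", "x"),
            da.zeros((NUM_INIT_TIMES, 100, 100), chunks=(1, 100, 100)),
        )
    }
)

# compute=False writes the Zarr metadata (and any eagerly-loaded coordinates)
# but none of the dask-backed UKV chunks. Workers can then fill in regions
# in any order, as in the snippets further down this thread.
template.to_zarr("test_regions.zarr", mode="w", compute=False)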

JackKelly commented 1 year ago

But, before making this change, I'll run some experiments with the code as is, to get a feel for whether this is even a problem!

JackKelly commented 1 year ago

Converting two NWP init times (using Wholesale1 & Wholesale2) takes 54 seconds on my NUC, and very nearly runs out of RAM.

Downcasting the dataset to float16 before passing it from the worker process to the main process speeds this up to 41 seconds, which does hint that there's considerable overhead in passing the object between processes.
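
For reference, the downcast is roughly something like this in the worker, just before the dataset is returned (a sketch only; the data variable name "UKV" is assumed from later in this thread):

import numpy as np
import xarray as xr

def downcast_to_float16(ds: xr.Dataset) -> xr.Dataset:
    """Shrink the bytes that must be pickled and piped back to the main process."""
    # Assumes the NWP data lives in a single data variable called "UKV".
    return ds.assign(UKV=ds["UKV"].astype(np.float16))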

JackKelly commented 1 year ago

Not passing anything back to the main process (and hence not writing anything to disk) takes 32 seconds.

JackKelly commented 1 year ago

I've done some experiments using `dataset.to_zarr(region=...)`... it looks very doable to write Zarr chunks in arbitrary order, in parallel, after first constructing the metadata. I think it could work something like this...

Each xr.Dataset will contain two DataArrays: the "UKV" data, and a "chunk_exists" DataArray: a 1D boolean array with one chunk per element (so, yeah, the individual chunks will be tiny!) which just records which UKV chunks have actually been written to disk completely.

Why? Consider what happens if we write metadata saying we've got one year of NWP init times for 2022, but then the script crashes after writing only four arbitrary init_time chunks to disk. When we re-run the script, it will see that the init_time coords extend to the end of 2022, so how will it know that it hasn't finished converting all the 2022 grib files to Zarr chunks? We could do something like ds["UKV"].isnull().max(dim=["variable", "step", "y", "x"]), but that would load every Zarr chunk. We could write individual marker files to disk to indicate which chunks have been written, but it's tidier to keep that bookkeeping inside the Zarr (and it should be easy to delete this data if needed). A sketch of this bookkeeping follows below.
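
A hedged sketch of that bookkeeping (names and shapes are illustrative; in the real store, chunk_exists would be encoded with a chunk size of 1 so each element can be rewritten independently):

import numpy as np
import xarray as xr

def add_chunk_exists(template: xr.Dataset) -> xr.Dataset:
    """Add a 1D boolean DataArray, one element per init_time chunk, all False."""
    template = template.copy()
    template["chunk_exists"] = (
        "init_time",
        np.zeros(template.sizes["init_time"], dtype=bool),
    )
    return template

def unfinished_init_times(zarr_path: str) -> np.ndarray:
    """On a re-run, find which init_time chunks were never completely written,
    loading only the tiny chunk_exists array rather than any UKV chunks."""
    ds = xr.open_zarr(zarr_path)
    # Assumes init_time is stored as a coordinate in the Zarr store.
    return ds["init_time"].values[~ds["chunk_exists"].values]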

In the main process, before launching the pool of workers:

If we have to create new metadata or update existing metadata then, in the main process:

When we actually write data to disk, we can use `imap_unordered` (see the sketch after the snippet below).

We can write actual chunks like this:

# The drop_vars is necessary, otherwise Zarr will try to
# overwrite the variable, step, y, and x coord arrays.
dataset.drop_vars(['variable', 'step', 'y', 'x']).to_zarr(
    "test_regions.zarr",
    region={"init_time": slice(10, 20)},  # integer index slice
)
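
And a self-contained toy showing how the pieces could fit together with `imap_unordered`: the store is pre-allocated up front, each worker writes its own region directly, and only a small integer ever crosses the process boundary. Shapes, paths, and helper names are all illustrative, not the repo's actual code:

import multiprocessing

import dask.array as da
import numpy as np
import xarray as xr

ZARR_PATH = "test_regions.zarr"  # illustrative path
NUM_INIT_TIMES = 8

def write_one_init_time(idx: int) -> int:
    """Worker: build (or, in reality, load from grib) the data for one
    init_time and write it straight into its region of the Zarr store."""
    ds = xr.Dataset({"UKV": (("init_time", "y", "x"), np.random.rand(1, 100, 100))})
    ds.to_zarr(ZARR_PATH, region={"init_time": slice(idx, idx + 1)})
    return idx  # only this small int is pickled and sent back

if __name__ == "__main__":
    # Pre-allocate the store: metadata only, no UKV data written yet.
    template = xr.Dataset(
        {
            "UKV": (
                ("init_time", "y", "x"),
                da.zeros((NUM_INIT_TIMES, 100, 100), chunks=(1, 100, 100)),
            )
        }
    )
    template.to_zarr(ZARR_PATH, mode="w", compute=False)

    with multiprocessing.Pool() as pool:
        # Completion order no longer matters, so imap_unordered is fine.
        for idx in pool.imap_unordered(write_one_init_time, range(NUM_INIT_TIMES)):
            print(f"init_time chunk {idx} written")  # e.g. update chunk_exists here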

JackKelly commented 1 year ago

On second thoughts... this isn't a priority for me, especially if I downsample the NWPs in the worker process before passing them to the Zarr-writing process.

The next task I plan to work on is downsampling the NWPs, ready for the National PV forecasting experiments.