pangeo-data / xESMF

Universal Regridder for Geospatial Data
http://xesmf.readthedocs.io/
MIT License

xe.Regridder does not work when parallel is set to true #299

Closed mgorfer closed 9 months ago

mgorfer commented 9 months ago

I have an input xarray dataset that I want to regrid:

<xarray.Dataset>
Dimensions:    (longitude: 1440, latitude: 721, time: 132)
Coordinates:
  * longitude  (longitude) float32 0.0 0.25 0.5 0.75 ... 359.0 359.2 359.5 359.8
  * latitude   (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * time       (time) datetime64[ns] 2000-01-01 2000-02-01 ... 2010-12-01
Data variables:
    d2m        (time, latitude, longitude) float32 dask.array<chunksize=(12, 721, 1440), meta=np.ndarray>

First I create the grid and then the regridder:

    ds_out = xe.util.grid_global(2.5, 2.5)
    regridder = xe.Regridder(
        ds,
        ds_out,
        "bilinear",
        periodic=True,
        parallel=True,
    )

If I have parallel set to true, I get an error message:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File <timed exec>:1

File ~/python_file.py:41, in regrid_to_lat_lon(ds, target_res_lat, target_res_lon)
     23 """
     24 Regrid any grid to a res_lat x res_lon degree lat lon grid.
     25 
   (...)
     38     Regridded dataset.
     39 """
     40 ds_out = xe.util.grid_global(target_res_lat, target_res_lon)
---> 41 regridder = xe.Regridder(
     42     ds,
     43     ds_out,
     44     "bilinear",
     45     periodic=True,
     46     parallel=True,
     47 )
     48 dr_out = regridder(ds)

File /[...]/python3.9/site-packages/xesmf/frontend.py:955, in Regridder.__init__(self, ds_in, ds_out, method, locstream_in, locstream_out, periodic, parallel, **kwargs)
    952     self.out_coords = {lat_out.name: lat_out, lon_out.name: lon_out}
    954 if parallel:
--> 955     self._init_para_regrid(ds_in, ds_out, kwargs)

File /[...]/python3.9/site-packages/xesmf/frontend.py:973, in Regridder._init_para_regrid(self, ds_in, ds_out, kwargs)
    971     ds_out['mask'] = mask
    972 else:
--> 973     ds_out_chunks = tuple([ds_out.chunksizes[i] for i in self.out_horiz_dims])
    974     ds_out = ds_out.coords.to_dataset()
    975     mask = da.ones(self.shape_out, dtype=bool, chunks=ds_out_chunks)

File /[...]/python3.9/site-packages/xesmf/frontend.py:973, in <listcomp>(.0)
    971     ds_out['mask'] = mask
    972 else:
--> 973     ds_out_chunks = tuple([ds_out.chunksizes[i] for i in self.out_horiz_dims])
    974     ds_out = ds_out.coords.to_dataset()
    975     mask = da.ones(self.shape_out, dtype=bool, chunks=ds_out_chunks)

File /[...]/python3.9/site-packages/xarray/core/utils.py:455, in Frozen.__getitem__(self, key)
    454 def __getitem__(self, key: K) -> V:
--> 455     return self.mapping[key]

KeyError: 'y'

When I set parallel to false, everything works and my regridder looks like this:

xESMF Regridder 
Regridding algorithm:       bilinear 
Weight filename:            bilinear_721x1440_72x144_peri.nc 
Reuse pre-computed weights? False 
Input grid shape:           (721, 1440) 
Output grid shape:          (72, 144) 
Periodic in longitude?      True

I am using Python 3.9.18, Xarray 2023.8.0, and xesmf 0.8.1.

huard commented 9 months ago

Thanks for the detailed report. This is a brand new feature; we'll look into this and release a fix.

aulemahal commented 9 months ago

This is a confusing caveat of parallel=True and we should raise a better error message.

As the documentation of parallel says (emphasis mine):

weights are computed in parallel with Dask on subsets of the output grid using chunks of the output grid.

ds_out, the output of grid_global, has no data variables and thus no chunks (the error arises because the dimension name isn't a key of the chunksizes dictionary). Therefore, xESMF doesn't know how to parallelize the weights generation.

An easy way out is:

ds_out = ds_out.assign(mask = ds_out.lon.notnull() & ds_out.lat.notnull()).chunk(y=NNN, x=MMM)

Where you choose NNN and MMM according to your situation.

I'll open a PR that adds a meaningful error.
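
The failure mode described above can be reproduced without xESMF. The sketch below is only an illustration: `make_grid` is a hypothetical stand-in for the output of grid_global (it assumes 2D lon/lat coordinates on dimensions y and x and omits the cell bounds), and the chunk sizes 36 and 72 are arbitrary. It assumes xarray and dask are installed.

```python
import numpy as np
import xarray as xr


def make_grid(d_lon=2.5, d_lat=2.5):
    """Hypothetical stand-in for xe.util.grid_global: cell centers only."""
    lon_1d = np.arange(-180 + d_lon / 2, 180, d_lon)
    lat_1d = np.arange(-90 + d_lat / 2, 90, d_lat)
    lon, lat = np.meshgrid(lon_1d, lat_1d)
    return xr.Dataset(coords={"lon": (("y", "x"), lon), "lat": (("y", "x"), lat)})


ds_out = make_grid()
# Nothing is dask-backed yet, so the chunksizes mapping is empty and
# ds_out.chunksizes["y"] would raise KeyError: 'y'.
has_chunks_before = "y" in ds_out.chunksizes

# Add a mask variable and chunk it (requires dask); 36 and 72 are arbitrary.
mask = ds_out.lon.notnull() & ds_out.lat.notnull()
ds_chunked = ds_out.assign(mask=mask).chunk({"y": 36, "x": 72})
chunks_after = dict(ds_chunked.chunksizes)
print(has_chunks_before, chunks_after)
```

Once the dataset carries a chunked variable, the dimension names appear as keys of chunksizes, which is what the traceback shows the parallel code path looking up.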

mgorfer commented 9 months ago

I am not sure that raising an error message is an appropriate solution for the usage of xe.util.grid_global. As xe.util.grid_global is part of xESMF, I would expect it to work automatically with parallel=True, without the user needing to modify the global grid object manually by adding data variables with chunks. Maybe the (chunked) mask data variable could simply be added to ds_out during the regridding process when needed?

But I followed your approach, which works perfectly. I am now using:

ds_out = xe.util.grid_global(target_res_lat, target_res_lon, cf=True)
ds_out = ds_out.assign(mask=ds_out.lat.notnull() & ds_out.lon.notnull()).chunk()

Are there recommended chunk sizes for the mask data variable of ds_out?

Three further additions:

  1. I tried using ds_out with the chunked mask data variable for regridding a ds without chunks. It works normally, whether parallel is set to true or false, and I did not see any speed difference. Is there any danger in always using it like this?
  2. There is a spurious latitude_longitude float64 nan coordinate in ds_out. Is it necessary?
  3. Sometimes I get an error message [WARNING] yaksa: 10 leaked handle pool objects when using parallel=True.

aulemahal commented 9 months ago

Well, you are right, but I don't think parallel=True really is one of those "user-doesn't-need-to-know" settings. At least, it doesn't feel ready for that yet, mostly because it depends on a user-provided chunking, which itself depends on the available RAM and other specs of the machine. But this is debatable. It would be easy to add arguments output_chunks and mask to grid_global; would that make sense to you?

Speed

I did

import xarray as xr
import xesmf as xe

ds = xr.tutorial.open_dataset('air_temperature')
ds_out = xe.util.grid_global(2.5, 2.5)
ds_out = ds_out.assign(mask = ds_out.lon.notnull() & ds_out.lat.notnull())

# Chunked, parallel
reg = xe.Regridder(ds, ds_out.chunk(), 'bilinear', periodic=True, parallel=True)
# took 2.94 s

# Chunked, not parallel
reg = xe.Regridder(ds, ds_out.chunk(), 'bilinear', periodic=True, parallel=False)
# took 476 ms

# Not chunked, not parallel
reg = xe.Regridder(ds, ds_out, 'bilinear', periodic=True, parallel=False)
# took 241 ms

I don't have the same input ds as you, nor the same machine, but I see here that parallelizing the weights generation is much slower than doing it the conventional way. This is what I expected given how the parallelization is implemented: to circumvent issues in ESMF, it is done at a very high level, which adds a lot of overhead. This goes back to the beginning of this comment: parallel=True is useful when the alternative is impossible and we are able to chunk the output wisely so that the process fits in the available RAM. To answer your question: no, I don't think parallel should be set to true "just in case". Many regridding problems will not gain anything from it.

As a rule of thumb, I would say parallel=True is useful for upscaling, when the output grid is very large, larger than the input. The current implementation only parallelizes over the chunks of the output, and thus each output chunk needs to "see" the full input grid, which is not efficient when that input becomes large. Parallelizing over the input chunks is a much more complex problem, so it was set aside for the moment. We are evidently open to contributions to improve the implementation!
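
To make the cost argument concrete, here is a purely conceptual sketch in plain Python. It is not xESMF's code: `output_blocks` is a hypothetical helper that tiles the output grid into chunks, and the bookkeeping simply counts how many input cells are touched if every output block is paired with the whole input grid, as described above. The grid shapes come from the original report; the chunk sizes are arbitrary.

```python
from itertools import product


def output_blocks(shape, chunks):
    """Return the (y_slice, x_slice) pairs that tile a 2D grid of `shape`."""
    ny, nx = shape
    cy, cx = chunks
    y_slices = [slice(i, min(i + cy, ny)) for i in range(0, ny, cy)]
    x_slices = [slice(j, min(j + cx, nx)) for j in range(0, nx, cx)]
    return list(product(y_slices, x_slices))


input_shape = (721, 1440)  # 0.25 degree input grid from the report
output_shape = (72, 144)   # 2.5 degree output grid
blocks = output_blocks(output_shape, chunks=(36, 72))

# One weight-generation task per output block; each task needs the whole input.
input_cells_per_task = input_shape[0] * input_shape[1]
total_input_cells_read = len(blocks) * input_cells_per_task
print(len(blocks), input_cells_per_task, total_input_cells_read)
```

Under this model, splitting a small output grid into more blocks only multiplies the number of times the large input grid has to be handled, which is consistent with the timings reported above.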

latitude_longitude

Well, this variable is how the CF conventions mark the grid as a lat/lon one (Plate Carrée, EPSG:4326). It is indeed unnecessary. It was added there for testing purposes, but I think we can remove it.

Yaksa warnings

I do too and I don't know where they come from...

mgorfer commented 9 months ago

Thank you for your explanation!

I have to admit, I do not know a lot about how to optimize my functions using chunks yet. But the default way of opening a folder full of netcdf files in xarray, open_mfdataset, creates a dataset with automatically selected chunks. Therefore, until you manually load that dataset, it is the default.

And until this loading, everything runs using dask arrays in the background and is hopefully optimized for parallel computation. (That at least is what I was hoping to accomplish by not loading my datasets before saving the final result again to a netcdf file.) All the xarray functions just work with dask and without dask, and there is no need to think about the RAM and hardware specs or to add custom arguments for either case.

Because I do not really know which option is best, I was hoping that xESMF would just detect whether the input dataset contains dask arrays and then choose the most performant option for regridding them, even if that means loading the dataset before regridding, or using a target grid with or without chunks. Therefore, even though it would be very interesting for me to contribute here, I do not really have the experience.
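
For reference, checking whether a dataset contains dask arrays can be done from user code. The sketch below is a minimal illustration, not something xESMF is known to do: `is_dask_backed` is a hypothetical helper, the toy dataset only borrows the d2m variable name from the report, and dask must be installed for the `.chunk` call.

```python
import numpy as np
import xarray as xr


def is_dask_backed(ds):
    """Return True if any variable in the dataset wraps a chunked (dask) array."""
    return any(var.chunks is not None for var in ds.variables.values())


# Toy dataset held in memory as a NumPy array.
ds_eager = xr.Dataset({"d2m": (("time", "y", "x"), np.zeros((12, 4, 8)))})
# Same data, now lazily chunked along time (requires dask).
ds_lazy = ds_eager.chunk({"time": 6})
print(is_dask_backed(ds_eager), is_dask_backed(ds_lazy))
```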

I will do further speed testing with my input dataset. It contains several air variables of the ERA5 monthly averaged data on pressure levels from 1940 to present, which I want to downscale from a 0.25° x 0.25° resolution to a much broader 5° x 5° resolution.

aulemahal commented 9 months ago

Just to clear some things up, xESMF uses ESMF to generate the regridding function, but it uses pure xarray to apply that function.

reg = xe.Regridder(...)  # Generating the regridding function, the regridding weights
dsout = reg(ds_in)  # Applying the function

Up until xESMF 0.7.1, we had two limitations:

In xESMF 0.8, both have been fixed, with the limitations on the second one that I highlighted above. parallel=True only applies to this second experimental feature. Even when it is False, dask is enabled when applying the function.

Thus, unless you have a very very large output grid (you are performing upscaling), xESMF 0.8 should feel as plug-and-play as the other xarray tools. Going from 0.25° to 5°, you should not need parallel=True.
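
The split between generating the weights and applying them can be pictured with a toy NumPy example. This is only a sketch under simplifying assumptions: the weights are treated as a small dense matrix mapping flattened input cells to output cells, and the averaging rule is invented for illustration rather than produced by ESMF.

```python
import numpy as np

n_in, n_out = 8, 4

# Step 1, "generate": build the weights once. Here each output cell is the
# mean of two neighbouring input cells (a toy stand-in for real weights).
weights = np.zeros((n_out, n_in))
for i in range(n_out):
    weights[i, 2 * i : 2 * i + 2] = 0.5

# Step 2, "apply": reuse the same weights for every time step.
data_in = np.arange(3 * n_in, dtype=float).reshape(3, n_in)  # (time, cell_in)
data_out = data_in @ weights.T                               # (time, cell_out)
print(data_out.shape)
```

The expensive part is step 1, which happens once per pair of grids; step 2 is a matrix product over the spatial dimension, which is why it can run chunk by chunk along time regardless of how the weights were generated.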

mgorfer commented 9 months ago

That actually cleared up a lot of things for me! I am sorry for not reading the available information on https://xesmf.readthedocs.io/ before starting this issue / discussion.

I thought parallel=True was the argument that enables lazy evaluation on Dask arrays.

mgorfer commented 9 months ago

I have done some quick speed testing for my function.

%timeit grd.Era5AhcManual()._read_data(target_res_lat_lon=2.5, xesmf_parallel=False)
24.3 s ± 401 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit grd.Era5AhcManual()._read_data(target_res_lat_lon=2.5, xesmf_parallel=True)
29.7 s ± 494 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

There is lots of other stuff going on in this function, but 5 seconds of saved time is pretty substantial anyway! So for my 2.5 × 2.5 global grid, not using parallel is (as you indicated) the faster option.