pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

Very slow coordinate assignment with dask array #2867

Closed (jbusecke closed this issue 4 years ago)

jbusecke commented 5 years ago

I am trying to reconstruct vertical cell depth from a z-star ocean model. This involves a few operations on both the dimensions and the coordinates of a dataset like this:

<xarray.Dataset>
Dimensions:                              (nv: 2, run: 2, st_edges_ocean: 51, st_ocean: 50, st_ocean_sub02: 10, sw_edges_ocean: 51, sw_ocean: 50, time: 240, xt_ocean: 360, xu_ocean: 360, yt_ocean: 200, yu_ocean: 200)
Coordinates:
    area_e                               (yt_ocean, xu_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    area_n                               (yu_ocean, xt_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    area_t                               (yt_ocean, xt_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    area_u                               (yu_ocean, xu_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    dxt                                  (yt_ocean, xt_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    dxtn                                 (yu_ocean, xt_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    dxu                                  (yu_ocean, xu_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    dyt                                  (yt_ocean, xt_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    dyte                                 (yt_ocean, xu_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    dyu                                  (yu_ocean, xu_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    geolat_c                             (yu_ocean, xu_ocean) float32 dask.array<shape=(200, 360), chunksize=(200, 360)>
    geolat_e                             (yt_ocean, xu_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    geolat_n                             (yu_ocean, xt_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    geolat_t                             (yt_ocean, xt_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    geolat_u                             (yu_ocean, xu_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    geolon_c                             (yu_ocean, xu_ocean) float32 dask.array<shape=(200, 360), chunksize=(200, 360)>
    geolon_e                             (yt_ocean, xu_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    geolon_n                             (yu_ocean, xt_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    geolon_t                             (yt_ocean, xt_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    geolon_u                             (yu_ocean, xu_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    ht                                   (yt_ocean, xt_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    kmt                                  (yt_ocean, xt_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
  * nv                                   (nv) float64 1.0 2.0
  * run                                  (run) object 'control' 'forced'
  * st_edges_ocean                       (st_edges_ocean) float64 0.0 ... 5.5e+03
  * st_ocean                             (st_ocean) float64 5.034 ... 5.395e+03
  * st_ocean_sub02                       (st_ocean_sub02) float64 5.034 ... 98.62
  * sw_edges_ocean                       (sw_edges_ocean) float64 5.034 ... 5.5e+03
  * sw_ocean                             (sw_ocean) float64 10.07 ... 5.5e+03
  * time                                 (time) object 2181-01-16 12:00:00 ... 2200-12-16 12:00:00
    tmask                                (yt_ocean, xt_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    tmask_region                         (yt_ocean, xt_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    umask                                (yu_ocean, xu_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    umask_region                         (yu_ocean, xu_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
    wet_t                                (yt_ocean, xt_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 360)>
  * xt_ocean                             (xt_ocean) float64 -279.5 ... 79.5
  * xu_ocean                             (xu_ocean) float64 -279.0 ... 80.0
  * yt_ocean                             (yt_ocean) float64 -81.5 -80.5 ... 89.5
  * yu_ocean                             (yu_ocean) float64 -81.0 -80.0 ... 90.0
    dst                                  (st_ocean, yt_ocean, xt_ocean) float64 dask.array<shape=(50, 200, 360), chunksize=(50, 200, 360)>
    dswt                                 (sw_ocean, yt_ocean, xt_ocean) float64 dask.array<shape=(50, 200, 360), chunksize=(50, 200, 360)>
    dxte                                 (yt_ocean, xu_ocean) float64 dask.array<shape=(200, 360), chunksize=(200, 359)>
    dytn                                 (yu_ocean, xt_ocean) float64 dask.array<shape=(200, 360), chunksize=(199, 360)>

The problematic step is assigning the calculated dask arrays back to the original dataset. This happens in a function like this:

def add_vertical_spacing(ds):
    grid = Grid(ds)
    ds.coords['dst'] = calculate_ds(ds, dim='st')
    ds.coords['dswt'] = calculate_ds(ds, dim='sw')
    ds.coords['dzt'] = calculate_dz(ds['eta_t'], ds['ht'], ds['dst'])
    ds.coords['dzu'] = grid.min(grid.min(ds['dzt'], 'X'), 'Y')
    return ds

This takes very long compared to a version where I assign the values as data variables:

def add_vertical_spacing(ds):
    grid = Grid(ds)
    ds.coords['dst'] = calculate_ds(ds, dim='st')
    ds.coords['dswt'] = calculate_ds(ds, dim='sw')
    ds['dzt'] = calculate_dz(ds['eta_t'], ds['ht'], ds['dst'])
    ds['dzu'] = grid.min(grid.min(ds['dzt'], 'X'), 'Y')
    return ds
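
For reference, here is a minimal, self-contained sketch of the two assignment paths on a small synthetic dataset. Grid, calculate_ds and calculate_dz are not shown above, so a plain lazy dask expression stands in for the computed cell depth; the variable names are reused from the real dataset but the values are meaningless. This sketch does not necessarily reproduce the slowdown, it just makes the comparison concrete:

import time

import dask.array as dsa
import xarray as xr

# Toy stand-ins for the real model output
ds_toy = xr.Dataset(
    {"eta_t": (("time", "yt_ocean", "xt_ocean"),
               dsa.zeros((240, 200, 360), chunks=(1, 200, 360)))},
    coords={"ht": (("yt_ocean", "xt_ocean"),
                   dsa.ones((200, 360), chunks=(200, 360)))},
)

# Stand-in for calculate_dz(...): any lazy dask-backed DataArray will do
dz = ds_toy["eta_t"] + ds_toy["ht"]

tic = time.perf_counter()
ds_toy.coords["dzt"] = dz      # version 1: assign as a coordinate
print("coordinate assignment:   ", time.perf_counter() - tic)

tic = time.perf_counter()
ds_toy["dzt_var"] = dz         # version 2: assign as a data variable
print("data variable assignment:", time.perf_counter() - tic)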

I have not been able to reproduce this problem in a smaller example yet, and I realize that my example is quite complex (e.g. it uses functions that are not shown). But I suspect that something triggers the computation of the array when assigning a coordinate.
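
One way to test that suspicion would be to register a dask scheduler callback and count how many tasks execute during the assignment; for a truly lazy assignment the count should stay at zero. This is an untested sketch that reuses ds_toy and dz from above and only applies to the local (threaded/synchronous) schedulers, not to distributed:

from dask.callbacks import Callback

class CountTasks(Callback):
    """Count how many dask tasks run while the context manager is active."""

    def __init__(self):
        super().__init__()
        self.n_tasks = 0

    def _posttask(self, key, result, dsk, state, worker_id):
        self.n_tasks += 1  # called once per executed task

with CountTasks() as counter:
    ds_toy.coords["dzt"] = dz

print(counter.n_tasks, "dask tasks executed during the coordinate assignment")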

I have profiled my more complex code involving this function and it seems like there is a substantial increase in calls to {method 'acquire' of '_thread.lock' objects}.

Profile output of the first version (assigning coordinates)

27662983 function calls (26798524 primitive calls) in 71.940 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall  filename:lineno(function)
   268632   46.914    0.000   46.914    0.000  {method 'acquire' of '_thread.lock' objects}
      438    4.296    0.010    4.296    0.010  {method 'read' of '_io.BufferedReader' objects}
    76883    1.909    0.000    1.939    0.000  local.py:240(release_data)
      144    1.489    0.010    4.519    0.031  rechunk.py:514(_compute_rechunk)
   ...

For the second version (assigning data variables)

12928834 function calls (12489174 primitive calls) in 16.554 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall  filename:lineno(function)
      438    3.841    0.009    3.841    0.009  {method 'read' of '_io.BufferedReader' objects}
     9492    3.675    0.000    3.675    0.000  {method 'acquire' of '_thread.lock' objects}
      144    1.673    0.012    4.712    0.033  rechunk.py:514(_compute_rechunk)
   ...
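
Listings in this format can be produced with cProfile and pstats wrapped around the call (add_vertical_spacing and ds here are the objects from above; any profiling wrapper gives the same kind of output):

import cProfile
import pstats

profiler = cProfile.Profile()
profiler.enable()
ds = add_vertical_spacing(ds)   # the function shown above
profiler.disable()

# "Ordered by: internal time" in the listings above corresponds to tottime
pstats.Stats(profiler).sort_stats("tottime").print_stats(10)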

Does anyone have a feel for why this could happen or how I could refine my testing to get to the bottom of this?
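
One further idea for narrowing this down (an untested sketch; it assumes a local dask scheduler and a dask version whose config accepts a callable scheduler): temporarily install a scheduler function that raises, so the traceback points at whichever call forces evaluation. ds_toy and dz are the stand-ins from the sketch above:

import dask

def raise_on_compute(dsk, keys, **kwargs):
    # stand-in scheduler: any implicit compute() of a dask graph ends up here
    raise RuntimeError("a dask graph was computed here")

with dask.config.set(scheduler=raise_on_compute):
    ds_toy.coords["dzt"] = dz   # raises with a traceback if the assignment computes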

Output of xr.show_versions()

INSTALLED VERSIONS
------------------
commit: None
python: 3.6.7 | packaged by conda-forge | (default, Feb 28 2019, 09:07:38) [GCC 7.3.0]
python-bits: 64
OS: Linux
OS-release: 2.6.32-696.30.1.el6.x86_64
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US
LOCALE: en_US.ISO8859-1
libhdf5: 1.10.4
libnetcdf: 4.6.2

xarray: 0.12.0
pandas: 0.24.2
numpy: 1.16.2
scipy: 1.2.1
netCDF4: 1.5.0.1
pydap: None
h5netcdf: None
h5py: None
Nio: None
zarr: 2.3.1
cftime: 1.0.3.4
nc_time_axis: 1.2.0
PseudonetCDF: None
rasterio: None
cfgrib: None
iris: None
bottleneck: None
dask: 1.1.5
distributed: 1.26.1
matplotlib: 3.0.3
cartopy: 0.17.0
seaborn: 0.9.0
setuptools: 40.8.0
pip: 19.0.3
conda: None
pytest: 4.4.0
IPython: 7.1.1
sphinx: None

shoyer commented 5 years ago

But I suspect that something triggers the computation of the array when assigning a coordinate.

This sounds correct to me.

I'm a little surprised by the significant performance difference between these two ways of adding coordinates. The implementations of these methods do differ slightly, but they should be pretty similar.

jbusecke commented 5 years ago

Can you think of a way I could diagnose this further? Sorry for the broad questions, but I am not very familiar with the xarray internals.

jbusecke commented 4 years ago

I believe this was fixed in a recent version. Closing.

dcherian commented 4 years ago

It may be. It would be good to check if you have the time.

jbusecke commented 4 years ago

I think this issue was actually a duplicate. I remember you pointing me to changes in 14.x that improved the performance, but I can't find the other issue right now. I will have an opportunity to test this in the coming days on some huge GFDL data.

jbusecke commented 4 years ago

I can confirm that this issue is resolved for my project. It no longer seems to make a difference in speed whether I assign the DataArray as a coordinate or as a data variable. Thanks for the fix!

dcherian commented 4 years ago

Great! Thanks for checking.