coecms / xmhw

Xarray version of Marine Heatwaves code by Eric Oliver
https://xmhw.readthedocs.io/en/latest/
Apache License 2.0

dask performance #65

Open florianboergel opened 1 year ago

florianboergel commented 1 year ago

Again, this still needs to be verified, but the dask documentation says one should avoid calling multiple delayed functions. That is why I removed the @dask.delayed decorator from calc_clim(). Instead, I now apply it only to calc_thresh() and calc_seas(), which I know are called independently.


import dask

# Loop over each cell to calculate climatologies; the main functions
# are delayed, so the loop is automatically run in parallel

# Original loop (replaced below):
# for c in ts.cell:
#     climls.append(
#         calc_clim(
#             ts.sel(cell=c),
#             tdim,
#             pctile,
#             windowHalfWidth,
#             smoothPercentile,
#             smoothPercentileWidth,
#             tstep,
#             skipna,
#         )
#     )

thresClimYearDelayed = []
seasClimYearDelayed = []

for c in ts.cell:
    tscell = ts.sel(cell=c)
    # Each call only builds a dask.delayed task; nothing is computed yet
    thresClimYearDelayed.append(
        calculate_thresh(tscell, pctile, skipna, tstep, windowHalfWidth, tdim,
                         smoothPercentile, smoothPercentileWidth))
    seasClimYearDelayed.append(
        calculate_seas(tscell, skipna, tstep, windowHalfWidth, tdim,
                       smoothPercentile, smoothPercentileWidth))

# Compute both lists in a single call so dask can schedule all tasks together;
# each result is a list with one DataArray per cell
thresClimYear, seasClimYear = dask.compute(thresClimYearDelayed, seasClimYearDelayed)

results = [thresClimYear, seasClimYear]

The different structure of results has to be accounted for in the code further down.
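A minimal, self-contained sketch of the per-cell pattern above, assuming synthetic data and two toy functions, toy_thresh and toy_seas, which are hypothetical stand-ins for calculate_thresh and calculate_seas (not xmhw code). It builds one list of delayed tasks per quantity, computes both lists in a single dask.compute call, and then concatenates the per-cell results along a cell dimension, which is one possible way to handle the changed structure of results.

```python
import dask
import numpy as np
import pandas as pd
import xarray as xr


# Hypothetical stand-ins for calculate_thresh / calculate_seas: each reduces
# one cell's timeseries to a single value per day of year.
@dask.delayed
def toy_thresh(ts_cell, pctile):
    return ts_cell.groupby("time.dayofyear").quantile(pctile / 100.0, dim="time")


@dask.delayed
def toy_seas(ts_cell):
    return ts_cell.groupby("time.dayofyear").mean("time")


# Synthetic data: 3 non-leap years of daily values for 4 cells
rng = np.random.default_rng(0)
time = pd.date_range("2001-01-01", "2003-12-31", freq="D")
ts = xr.DataArray(
    rng.normal(size=(len(time), 4)),
    coords={"time": time, "cell": np.arange(4)},
    dims=("time", "cell"),
)

thresh_delayed, seas_delayed = [], []
for c in ts.cell.values:
    ts_cell = ts.sel(cell=c)
    thresh_delayed.append(toy_thresh(ts_cell, 90))
    seas_delayed.append(toy_seas(ts_cell))

# One compute call returns (list of thresholds, list of seasonal means)
thresh_list, seas_list = dask.compute(thresh_delayed, seas_delayed)

# Reassemble the per-cell results into arrays with a "cell" dimension
thresh = xr.concat(thresh_list, dim="cell")
seas = xr.concat(seas_list, dim="cell")
```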

To make sure only one delayed function is called, I also removed the @dask.delayed decorator from runavg, so that calculate_thresh (and likewise calculate_seas) now looks like this:

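import dask

# window_roll, feb29 and runavg are xmhw helper functions; runavg is now a
# plain (non-delayed) function, so calling it here does not nest delayed calls.


# The function returns a single DataArray, not a tuple, so nout is not needed
@dask.delayed
def calculate_thresh(ts, pctile, skipna, tstep, windowHalfWidth, tdim,
                     smoothPercentile, smoothPercentileWidth):
    """Calculate the threshold for one grid cell at a time

    Parameters
    ----------
    ts: xarray DataArray
        Timeseries for a single grid cell
    pctile: int
        Threshold percentile used to detect events
    skipna: bool
        If True percentile and mean function will use skipna=True.
        Using skipna option is much slower
    tstep: bool
        If False, the value for 29 Feb (doy 60) is recalculated with feb29()
    windowHalfWidth: int
        Half width w of the window; the window has width 2*w+1
    tdim: str
        Name of the time dimension
    smoothPercentile: bool
        If True, the threshold is smoothed with a running average
    smoothPercentileWidth: int
        Width of the running average window

    Returns
    -------
    thresh_climYear: xarray DataArray
        Climatological threshold
    """
    # Stacked array with a new 'z' dimension representing a window of width 2*w+1
    twindow = window_roll(ts, windowHalfWidth, tdim)

    thresh_climYear = twindow.groupby("doy").quantile(
        pctile / 100.0, dim="z", skipna=skipna
    )
    # Calculate value for 29 Feb from mean of 28-29 Feb and 1 Mar
    if tstep is False:
        thresh_climYear = thresh_climYear.where(
            thresh_climYear.doy != 60, feb29(thresh_climYear)
        )

    if smoothPercentile:
        thresh_climYear = runavg(thresh_climYear, smoothPercentileWidth)

    thresh_climYear = thresh_climYear.chunk({"doy": -1})
    return thresh_climYear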
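To illustrate delaying only the outer per-cell function while its helpers run eagerly inside it, here is a numpy-only toy sketch. The helper window_pool and the function thresh_one_cell are hypothetical, not xmhw code: for each day of year they pool all values within +/- w days (wrapping around the year end) across all years and take a percentile. Leap years and the 29 Feb handling that feb29 does in xmhw are ignored.

```python
import dask
import numpy as np


def window_pool(values, doy, target_doy, half_width):
    # Plain (non-delayed) helper: keep values whose day of year lies within
    # +/- half_width days of target_doy, wrapping around the end of the year
    dist = np.abs(doy - target_doy)
    dist = np.minimum(dist, 365 - dist)
    return values[dist <= half_width]


@dask.delayed
def thresh_one_cell(values, doy, pctile, half_width):
    # Only this outer function is delayed; window_pool runs eagerly inside it,
    # so no delayed function is called from within another delayed function
    return np.array([
        np.percentile(window_pool(values, doy, d, half_width), pctile)
        for d in range(1, 366)
    ])


rng = np.random.default_rng(0)
n_years, n_cells = 3, 4
doy = np.tile(np.arange(1, 366), n_years)     # day of year of each time step
cells = rng.normal(size=(n_cells, doy.size))  # one timeseries per cell

# One delayed task per cell, all computed together
tasks = [thresh_one_cell(cells[i], doy, 90, 5) for i in range(n_cells)]
thresh = np.stack(dask.compute(*tasks))       # shape (n_cells, 365)
```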

If you think this makes sense, I can also open a pull request so the changes can be verified.

florianboergel commented 1 year ago

For me, this sped up the computation.

paolap commented 1 year ago

I'm not sure what you mean, as calc_clim is not delayed in my version. It's called in a loop, but the delayed steps are calc_thresh, calc_seas and runavg, as they're all called independently of each other. What dask advises against is calling a delayed function from within another delayed function, which would be the case here if calc_clim itself were decorated as delayed. I'm not disputing that your reordering of operations may have sped up the computation, only that it might be for other reasons. For example, in your version window_roll is called inside a delayed function, so perhaps delaying window_roll in the original structure would work as well.
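A sketch of the alternative suggested above, using hypothetical toy functions rather than xmhw code: make_windows stands in for window_roll, and each step is wrapped with dask.delayed separately. Delayed results are passed as arguments instead of being called from inside another delayed function, and because the same windows object feeds both the threshold and the seasonal task, dask runs the window step only once per cell.

```python
import dask
import numpy as np

calls = []  # records how often the window step actually runs


def make_windows(values, half_width):
    # Toy stand-in for window_roll: stack shifted copies of the series into
    # a (2*half_width+1, n) array, wrapping around at the ends
    calls.append(1)
    shifts = range(-half_width, half_width + 1)
    return np.stack([np.roll(values, s) for s in shifts])


def thresh_from_windows(windows, pctile):
    return np.percentile(windows, pctile, axis=0)


def seas_from_windows(windows):
    return windows.mean(axis=0)


# Delay each step separately and chain them by passing Delayed objects
make_windows_d = dask.delayed(make_windows)
thresh_d = dask.delayed(thresh_from_windows)
seas_d = dask.delayed(seas_from_windows)

rng = np.random.default_rng(1)
cells = rng.normal(size=(3, 100))

thresh_tasks, seas_tasks = [], []
for values in cells:
    windows = make_windows_d(values, 5)          # one window task per cell,
    thresh_tasks.append(thresh_d(windows, 90))   # shared by both
    seas_tasks.append(seas_d(windows))           # downstream tasks

thresh, seas = dask.compute(thresh_tasks, seas_tasks, scheduler="sync")
```

The synchronous scheduler only keeps the call counter deterministic for this toy; a real run would normally use the default threaded scheduler or a distributed cluster.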

Frankly, I don't remember why I didn't delay window_roll at the time.

I will try that when I have time; it's hard for me to test performance reliably, as the system we use is unreliable in that respect.