openclimatefix / ocf_datapipes

OCF's DataPipe based dataloader for training and inference
MIT License

Add channel diffing #314

Closed: dfulu closed this 1 month ago

dfulu commented 1 month ago

Some of the NWP channels are accumulations. For example, in the ECMWF data the channels ['dswrf', 'dlwrf', 'sr', 'duvrs'] are all accumulated. This causes issues in production if the NWP data is more delayed than the model saw in training, since the values get higher and higher throughout the forecast steps.

Therefore we need to diff these accumulated channels along the forecast step dimension.
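A minimal sketch of what the diff does, using NumPy in place of xarray and made-up values for one accumulated channel: the stored field is a running total since the init time, so its magnitude depends on how far into the forecast a sample sits, while the step-to-step differences recover the per-step values without that growing baseline.

```python
import numpy as np

# Made-up per-step increments for one accumulated channel
increments = np.array([0.0, 5.0, 12.0, 20.0, 18.0, 9.0])

# The stored field is the running total since init time, so it keeps growing
accumulated = np.cumsum(increments)

# Differencing along the step axis recovers the per-step values. As with
# xarray's diff(dim="step", label="lower"), diffed[i] = accumulated[i + 1] - accumulated[i]
# is attached to the earlier step i, so the output is one step shorter
diffed = np.diff(accumulated)

print(accumulated)
print(diffed)
```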

This pull request adds the diff to the time slicing function. This is a slightly more complex way of doing things, but the simpler approach of diffing the data when it is initially loaded performed very poorly.

For example, in this simple case the diffed data takes 5.5 times longer to load than the non-diffed data. I don't know why the slowdown is so large. More chunks along the step dimension need to be loaded for each slice, but nothing close to 5.5 times more. I wonder if this could be due to extra dask overhead and a more complex task graph.

```python
import numpy as np
import xarray as xr
import ocf_blosc2  # imported for its side effect: registers the Blosc2 codec this zarr uses

ds_ecmwf = xr.open_zarr("/mnt/disks/nwp_rechunk/nwp/ecmwf/UK_v2.zarr")
accum_channels = ['dswrf', 'dlwrf', 'sr', 'duvrs']

# Diff the accumulated channels along the step dimension. With label="lower"
# each difference is assigned to the earlier step, so one step is lost
xr_accum = ds_ecmwf.sel(variable=accum_channels).diff(dim="step", label="lower")

# Drop the last step of the other channels so both parts share the same steps
xr_non_accum = (
    ds_ecmwf.sel(variable=[v for v in ds_ecmwf.variable.values if v not in accum_channels])
    .isel(step=slice(0, -1))
)
ds_diffed = xr.concat([xr_accum, xr_non_accum], dim="variable")

def random_sel(ds):
    """Load a random 24x24, 8-step window to compare diffed and non-diffed speed"""
    lat0 = np.random.randint(0, len(ds.latitude) - 24)
    lon0 = np.random.randint(0, len(ds.longitude) - 24)
    step0 = np.random.randint(0, len(ds.step) - 16)
    init_time = np.random.randint(0, len(ds.init_time))
    return ds.isel(
        latitude=slice(lat0, lat0 + 24),
        longitude=slice(lon0, lon0 + 24),
        step=slice(step0, step0 + 8),
        init_time=init_time,
    ).compute()

# IPython/Jupyter timings
%timeit -n 3 -r 3 random_sel(ds_ecmwf)
# 1.48 s ± 420 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
%timeit -n 3 -r 3 random_sel(ds_diffed)
# 8.16 s ± 1.45 s per loop (mean ± std. dev. of 3 runs, 3 loops each)
```
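The comment before the benchmark argues that the extra step-wise chunks cannot explain a 5.5x slowdown. A rough back-of-the-envelope check, assuming a hypothetical layout of 40 forecast steps stored in chunks of 10 steps (the real chunking of UK_v2.zarr is not stated here): producing 8 diffed steps needs 9 raw steps, which only sometimes spills into one more chunk. The helpers `chunks_touched` and `mean_chunks` are illustrative, averaged over every valid window start.

```python
def chunks_touched(start, length, chunk_size):
    """Number of chunks along one axis overlapped by the slice [start, start + length)."""
    first = start // chunk_size
    last = (start + length - 1) // chunk_size
    return last - first + 1


def mean_chunks(length, n_steps, chunk_size):
    """Average chunks touched over every valid window start position."""
    starts = range(n_steps - length + 1)
    return sum(chunks_touched(s, length, chunk_size) for s in starts) / len(starts)


# Hypothetical layout: 40 forecast steps stored in chunks of 10 steps
N_STEPS, CHUNK = 40, 10

plain_mean = mean_chunks(8, N_STEPS, CHUNK)   # 8 raw steps for an 8-step window
diffed_mean = mean_chunks(9, N_STEPS, CHUNK)  # 9 raw steps to produce 8 diffed steps
print(f"{plain_mean:.2f} vs {diffed_mean:.2f} chunks ({diffed_mean / plain_mean:.2f}x)")
```

Under these assumptions the diffed window touches only about 7% more step chunks on average, which is consistent with the suspicion that the slowdown comes from somewhere else, such as dask overhead.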

In the version submitted in this pull request, enabling channel diffing (which is optional) causes a slowdown of less than 2x compared with the non-diffed data.
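A minimal NumPy sketch of optional diffing inside a slicing step, assuming data laid out as (step, channel, ...). The function name `slice_steps`, its `accum_channels` parameter, and the channel "t" are hypothetical and not the ocf_datapipes API. Only one step beyond the requested window is read and differenced, and the result matches diffing the whole step axis first and slicing afterwards, which is the property that lets the diff move into the time slicing.

```python
import numpy as np


def slice_steps(data, channels, start, length, accum_channels=None):
    """Return steps [start, start + length) of data, shaped (step, channel, ...).

    Hypothetical helper: if accum_channels is given, those channels are replaced
    by step-wise differences (label="lower" convention: the value at step i is
    data[i + 1] - data[i]); all other channels are returned unchanged.
    """
    if not accum_channels:
        return data[start:start + length]

    if start + length + 1 > len(data):
        raise ValueError("diffing needs one step beyond the requested window")

    # Read one extra step so the last step in the window can be differenced
    window = data[start:start + length + 1]
    out = window[:-1].copy()
    idx = [channels.index(c) for c in accum_channels]
    out[:, idx] = np.diff(window[:, idx], axis=0)
    return out


# Usage with made-up data: "t" is a hypothetical non-accumulated channel
channels = ["dswrf", "t", "dlwrf"]
rng = np.random.default_rng(1)
data = rng.random((12, 3, 2, 2))
data[:, [0, 2]] = np.cumsum(data[:, [0, 2]], axis=0)  # make two channels accumulate

plain = slice_steps(data, channels, start=2, length=4)
diffed = slice_steps(data, channels, start=2, length=4, accum_channels=["dswrf", "dlwrf"])

# Same result as diffing the full step axis first and slicing afterwards
full_diff = np.diff(data[:, [0, 2]], axis=0)[2:6]
print(np.allclose(diffed[:, [0, 2]], full_diff))
```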

peterdudfield commented 1 month ago

#286, does this solve this @dfulu?

dfulu commented 1 month ago

Yeah, this closes #286 except for the normalisation part. I'll keep this open for now until I can add that.