pangeo-data / distributed-array-examples

12 stars 0 forks source link

Transformed Eulerian Mean #3

Open dcherian opened 1 year ago

dcherian commented 1 year ago

This used to be a real pain: people would loop over a year's worth of files at a time to cover 60 years of data.

This example uses random data shaped like ERA5 (hourly time steps, 37 levels, a 721 × 1440 lat/lon grid).

import dask.array
import flox.xarray  # noqa: F401 -- not used directly; fails fast if flox (required by method="blockwise") is missing
import pandas as pd
import xarray as xr

dims = ("time", "level", "lat", "lon")
# Number of years of hourly data to synthesize. Increase it to make the
# problem bigger; the full ERA5 record is roughly 60 years.
nyears = 20
# Hourly steps over 365-day years, 37 levels, and a 721 x 1440 lat/lon grid.
shape = (nyears * 365 * 24, 37, 721, 1440)
# One chunk = 24 hourly steps x 1 level x the full lat/lon grid. The time axis
# starts at midnight (see date_range below), so every time chunk holds exactly
# one calendar day. That alignment is what lets the daily resample run blockwise.
chunks = (24, 1, -1, -1)

ds = xr.Dataset(
    {
        name: (dims, dask.array.random.random(shape, chunks=chunks))
        for name in ("U", "V", "W", "T")
    },
    # "h" is the hourly alias; "H" is deprecated as of pandas 2.2.
    coords={"time": pd.date_range("2001-01-01", periods=shape[0], freq="h")},
)

# Zonal (longitude) means and deviations from them.
zonal_means = ds.mean("lon")
anomaly = ds - zonal_means

# Products of the deviations: the eddy-flux terms u'v', v'T', and u'w'.
anomaly["uv"] = anomaly.U * anomaly.V
anomaly["vt"] = anomaly.V * anomaly.T
anomaly["uw"] = anomaly.U * anomaly.W

# Combine the zonal-mean fields with the zonal-mean eddy fluxes.
temdiags = zonal_means.merge(anomaly[["uv", "vt", "uw"]].mean("lon"))

# Daily means. method="blockwise" is handled by flox; it is valid here only
# because each time chunk is exactly one day (see `chunks` above).
temdiags = temdiags.resample(time="D").mean(method="blockwise")
temdiags  # display the (lazy) result when run as a notebook cell
dcherian commented 2 months ago

Rewritten to use the real ERA5 data from the WeatherBench 2 archive:

import xarray as xr
from xarray.groupers import TimeResampler

# Zarr store to write the daily diagnostics to. This must be set before
# running, e.g. "gs://<your-bucket>/era5-tem-daily.zarr".
OUTPUT_STORE = None

# Resampling frequency. The same value drives both the rechunk and the
# resample below. If the two disagreed, chunks would no longer line up with
# output periods and the reduction would stop being blockwise.
FREQ = "D"

# chunks={} backs each variable with dask using the store's own chunking.
ds = xr.open_zarr(
    "gs://weatherbench2/datasets/era5/1959-2023_01_10-full_37-1h-0p25deg-chunk-1.zarr",
    chunks={},
)

ds = ds[
    ["u_component_of_wind", "v_component_of_wind", "temperature", "vertical_velocity"]
].rename(
    {
        "u_component_of_wind": "U",
        "v_component_of_wind": "V",
        "temperature": "T",
        "vertical_velocity": "W",
    }
)

# Zonal (longitude) means and deviations from them.
zonal_means = ds.mean("longitude")
anomaly = ds - zonal_means

# Products of the deviations: the eddy-flux terms u'v', v'T', and u'w'.
anomaly["uv"] = anomaly.U * anomaly.V
anomaly["vt"] = anomaly.V * anomaly.T
anomaly["uw"] = anomaly.U * anomaly.W

# Combine the zonal-mean fields with the zonal-mean eddy fluxes.
temdiags = zonal_means.merge(anomaly[["uv", "vt", "uw"]].mean("longitude"))

# Option 1 (kept for reference): resampling directly from the on-disk
# chunking is very slow, because flox takes a long time to construct the graph.
# daily = temdiags.resample(time=FREQ).mean()

# Option 2: rechunk so that each chunk spans exactly one resampling period,
# which makes the reduction blockwise. Ideally this would happen automatically.
daily = temdiags.chunk(time=TimeResampler(FREQ)).resample(time=FREQ).mean()

if OUTPUT_STORE is None:
    raise ValueError("Set OUTPUT_STORE to a writable Zarr location before running.")
daily.to_zarr(OUTPUT_STORE)