google-research / weatherbench2

A benchmark for the next generation of data-driven global weather models.
https://weatherbench2.readthedocs.io
Apache License 2.0

conservative regridding with land-sea mask #171

Open markmbaum opened 3 months ago


I'm wondering how the weatherbench2 pipeline handles regridding of variables that have a land-sea mask, i.e. where land points are NaN. For example, when sea surface temperature is regridded conservatively, sea cells expand into land areas. The plot below shows an example of this.

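[Figure: sea surface temperature at 2020-01-01 00:00; panels "weatherbench2 240x121 dataset", "conservative regridding from 1440x721", and "difference"]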

Where both fields have values, the differences are extremely small, but the land-sea patterns are very different. Is there another step in the pipeline that restores a more realistic land-sea mask in the regridded data?
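The spreading of sea into land is what a conservative average that skips NaN cells (in the style of `np.nanmean`) would produce. The 1-D sketch below assumes that weatherbench2's `ConservativeRegridder` treats NaNs this way, dividing the overlap-weighted sum of valid values by the overlap-weighted area of valid cells. The helper names `overlap_weights` and `nan_skipping_mean` and the toy data are hypothetical. A coarse cell that overlaps even one sea cell receives a sea value, which pushes the coarse coastline landward.

```python
import numpy as np


def overlap_weights(src_edges, tgt_edges):
    """1-D conservative weights: fraction of each target cell covered by each source cell.

    Returns shape (n_target, n_source); each row sums to 1 when the target cell
    lies inside the source domain.
    """
    lo = np.maximum(tgt_edges[:-1, None], src_edges[None, :-1])
    hi = np.minimum(tgt_edges[1:, None], src_edges[None, 1:])
    overlap = np.clip(hi - lo, 0.0, None)
    return overlap / np.diff(tgt_edges)[:, None]


def nan_skipping_mean(weights, field):
    """Overlap-weighted average that ignores NaN source cells, like np.nanmean."""
    valid = ~np.isnan(field)
    total = weights @ np.where(valid, field, 0.0)
    covered = weights @ valid.astype(float)  # fraction of each target cell that is sea
    with np.errstate(invalid="ignore", divide="ignore"):
        return total / covered  # NaN only where no valid source cell overlaps


# Hypothetical data: 8 fine cells averaged onto 2 coarse cells; NaN marks land.
src_edges = np.linspace(0.0, 8.0, 9)
tgt_edges = np.array([0.0, 4.0, 8.0])
sst = np.array([np.nan, np.nan, np.nan, 290.0,    # 3 land cells + 1 sea cell
                np.nan, np.nan, np.nan, np.nan])  # all land

coarse = nan_skipping_mean(overlap_weights(src_edges, tgt_edges), sst)
# coarse[0] is 290.0 even though 75% of that coarse cell is land; coarse[1] is NaN.
```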

The full script producing that figure is below:

```python
from weatherbench2.evaluation import make_latitude_increasing
from weatherbench2.regridding import Grid, ConservativeRegridder
import xarray as xr
import matplotlib.pyplot as plt
import jax

# Enable 64-bit precision for JAX computations.
jax.config.update("jax_enable_x64", True)

# Source: 0.25-degree (1440x721) ERA5 sea surface temperature; land points are NaN.
hi_res = xr.open_dataset(
    "gs://weatherbench2/datasets/era5/1959-2022-full_37-6h-0p25deg_derived.zarr",
    engine="zarr",
)
hi_res = make_latitude_increasing(hi_res)  # sort latitude south-to-north
hi_res = hi_res["sea_surface_temperature"].sel(time="2020-01-01 00:00:00")
hi_res = hi_res.transpose("longitude", "latitude")

# Reference: the published 1.5-degree (240x121) conservatively regridded dataset.
lo_res = xr.open_dataset(
    "gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-240x121_equiangular_with_poles_conservative.zarr",
    engine="zarr",
)
lo_res = make_latitude_increasing(lo_res)
lo_res = lo_res["sea_surface_temperature"].sel(time="2020-01-01 00:00:00")
lo_res = lo_res.transpose("longitude", "latitude")

# Build source and target grids from the coordinate values (in degrees).
source_grid = Grid.from_degrees(lon=hi_res.longitude.values, lat=hi_res.latitude.values)
target_grid = Grid.from_degrees(lon=lo_res.longitude.values, lat=lo_res.latitude.values)

fig, axs = plt.subplots(3, 1, figsize=(6, 10))  # axs is already a 1-D array

# Panel 1: published 240x121 field (transposed so latitude is on the vertical axis).
r = axs[0].pcolormesh(lo_res.values.T)
plt.colorbar(r, ax=axs[0])
axs[0].set_title("weatherbench2 240x121 dataset")

# Panel 2: conservatively regrid the 0.25-degree field onto the 240x121 grid.
regridder = ConservativeRegridder(source_grid, target_grid)
regridded = regridder.regrid_dataset(hi_res)

r = axs[1].pcolormesh(regridded.values.T)
plt.colorbar(r, ax=axs[1])
axs[1].set_title("conservative regridding from 1440x721")

# Panel 3: difference; NaN wherever either field is NaN.
r = axs[2].pcolormesh((regridded - lo_res).values.T)
plt.colorbar(r, ax=axs[2])
axs[2].set_title("difference")

plt.tight_layout()
plt.show()
```
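If the pipeline has no such step, one possible post-processing approach is to regrid the validity (sea) mask alongside the data and re-mask target cells whose sea fraction falls below a threshold. The sketch below is illustrative only: the function names, the 0.5 threshold, and the block-averaging weights (a stand-in for true conservative weights, which on a sphere would also account for cell area in latitude) are assumptions and not part of weatherbench2. The `(longitude, latitude)` axis order matches the script above.

```python
import numpy as np


def block_weights(n_target, factor):
    """Equal-area averaging weights for an integer coarsening factor.

    Stand-in for real conservative weights; shape (n_target, n_target * factor).
    """
    return np.repeat(np.eye(n_target), factor, axis=1) / factor


def regrid_with_sea_mask(field, lon_w, lat_w, min_valid_fraction=0.5):
    """NaN-skipping separable regrid, then re-mask mostly-land target cells.

    field has dims (longitude, latitude); lon_w is (n_lon_out, n_lon_in) and
    lat_w is (n_lat_out, n_lat_in). Returns (masked_mean, valid_fraction).
    """
    valid = ~np.isnan(field)
    total = np.einsum("ab,cd,bd->ac", lon_w, lat_w, np.where(valid, field, 0.0))
    # valid_fraction is the regridded sea mask: sea area fraction of each target cell.
    valid_fraction = np.einsum("ab,cd,bd->ac", lon_w, lat_w, valid.astype(float))
    with np.errstate(invalid="ignore", divide="ignore"):
        mean = total / valid_fraction
    # Hypothetical rule: keep a target cell only if enough of its area is sea.
    return np.where(valid_fraction >= min_valid_fraction, mean, np.nan), valid_fraction


# Hypothetical 4x4 field coarsened to 2x2; NaN marks land.
fine = np.arange(16, dtype=float).reshape(4, 4) + 270.0
fine[[0, 0, 1, 2, 3], [0, 1, 0, 0, 0]] = np.nan
w = block_weights(2, 2)
coarse, sea_fraction = regrid_with_sea_mask(fine, w, w)
# coarse[0, 0] is NaN (only 25% sea); coarse[1, 0] is 281.0 (50% sea, kept).
```

Raising `min_valid_fraction` shrinks the coarse sea area, and lowering it toward zero recovers the plain NaN-skipping behaviour in which any sea overlap yields a value.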