jbusecke / xMIP

Analysis ready CMIP6 data in python the easy way with pangeo tools.
https://cmip6-preprocessing.readthedocs.io/en/latest/?badge=latest
Apache License 2.0
199 stars 44 forks source link

Seam issue when doing conservative regridding with xesmf #368

Open scoatsclim opened 2 months ago

scoatsclim commented 2 months ago
import cf_xarray as cfxr
import dask
import intake
import matplotlib.pyplot as plt
import xarray as xr
import xesmf as xe
from xmip.preprocessing import combined_preprocessing

# Reproduction script: conservatively regrid CNRM-CM6-1-HR `tos` onto a 1x1
# degree grid with xESMF and plot the result, which shows a seam.

# Getting the datasets
url = "https://storage.googleapis.com/cmip6/pangeo-cmip6.json"
col = intake.open_esm_datastore(url)
models = ["CNRM-CM6-1-HR"]
cat = col.search(
    table_id="Omon",
    grid_label="gn",
    experiment_id="historical",
    variable_id="tos",
    source_id=models,
)

# Loading data (requires `import dask` for the config context below)
z_kwargs = {"consolidated": True, "decode_times": True}
with dask.config.set(**{"array.slicing.split_large_chunks": True}):
    dset_dict = cat.to_dataset_dict(
        zarr_kwargs=z_kwargs, preprocess=combined_preprocessing
    )

# Target grid: 1x1 degree on 0..360 longitudes (centers 0.5..359.5, bounds
# 0..360). Building it directly avoids shifting `grid_global`'s -180..180 grid,
# which needed a non-obvious `361 + lon_b` offset to avoid a duplicated bound.
coarse_grid = xe.util.grid_2d(0, 360, 1, -90, 90, 1)

# Grabbing dataset
ds_in = dset_dict["CMIP.CNRM-CERFACS.CNRM-CM6-1-HR.historical.Omon.gn"]
ds_in = ds_in.isel(member_id=0).squeeze()
ds_in = ds_in.sel(time=slice("1856", "2014"))

# Input data setup for regrid: land/ocean mask taken from one time step
in_mask = ds_in["tos"].isel(time=100).notnull()
ds_in_mask = ds_in.assign(mask=in_mask)
# Build xESMF corner coordinates from the CMIP cell vertices. The
# `*_verticies` names are presumably what xmip's preprocessing produces.
# NOTE(review): the vertex order fed to `bounds_to_vertices` is the suspected
# cause of the seam; see the follow-up helpers that search for a working order.
if "vertices_latitude" not in ds_in.variables:
    lat_corners = cfxr.bounds_to_vertices(ds_in.lat_verticies, "vertex")
    lon_corners = cfxr.bounds_to_vertices(ds_in.lon_verticies, "vertex")
else:
    lat_corners = cfxr.bounds_to_vertices(ds_in["vertices_latitude"], "vertex")
    lon_corners = cfxr.bounds_to_vertices(ds_in["vertices_longitude"], "vertex")
ds_in_mask.coords["lon_b"] = lon_corners
ds_in_mask.coords["lat_b"] = lat_corners

# Regridder
reg_mask = xe.Regridder(
    ds_in_mask,
    coarse_grid,
    "conservative",
    ignore_degenerate=True,
    periodic=True,
)

# Regrid
ds_out_tos = reg_mask(ds_in_mask.squeeze())

# Plotting to show the seam issue:
fig, ax = plt.subplots(ncols=1, figsize=(12, 4))
ds_out_tos["tos"][0, :, :].plot(ax=ax, add_colorbar=True)

Julius told me he had a solution for this, sorry for the messy code. Thanks!

This code produces a seam in the regridded tos output.

Screenshot 2024-08-14 at 11 19 49 AM
jbusecke commented 2 months ago

I am fairly sure this is due to the order of the cell vertices. I ran into this problem in another context.

Just dropping some code I used to fix this in a brute-force manner (by trying every permutation of the vertex order and testing the resulting corner points).

import itertools
# also relies on xarray (imported as `xr`) and cf_xarray being importable

def cmip_bounds_to_xesmf(ds: xr.Dataset, order=None):
    """Add xESMF-style ``lon_b``/``lat_b`` corner coordinates built from cell vertices.

    If ``ds`` already contains both ``lon_b`` and ``lat_b`` it is returned
    unchanged. Otherwise ``lon_verticies``/``lat_verticies`` are loaded and
    converted with ``cf_xarray.bounds_to_vertices`` along the ``vertex``
    dimension, and the results are assigned as coordinates.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset carrying ``lon_verticies``/``lat_verticies`` with a ``vertex``
        dimension, unless ``lon_b``/``lat_b`` are already present.
    order : optional
        Passed through unchanged to ``cf_xarray.bounds_to_vertices``. The right
        value depends on how xmip arranged the vertex dimension; a wrong order
        shows up as stripes in the regridded output.

    Returns
    -------
    xr.Dataset
        ``ds`` with ``lon_b``/``lat_b`` coordinates.
    """
    # Imported here because only the `cfxr` alias is imported at module level.
    import cf_xarray

    if all(var in ds.variables for var in ["lon_b", "lat_b"]):
        return ds
    return ds.assign_coords(
        lon_b=cf_xarray.bounds_to_vertices(
            ds.lon_verticies.load(), bounds_dim="vertex", order=order
        ),
        lat_b=cf_xarray.bounds_to_vertices(
            ds.lat_verticies.load(), bounds_dim="vertex", order=order
        ),
    )

def test_vertex_order(ds):
    """Raise if the corner coordinates built for a small test patch are not monotonic.

    A 2x2-cell patch of ``ds`` (``x`` and ``y`` indices 20-21) is selected.
    Variables with a ``lev`` or ``time`` dimension and the label variables are
    removed, and ``lon_b``/``lat_b`` are built with ``cmip_bounds_to_xesmf``.
    The check requires ``lon_b`` to strictly increase along ``x_vertices``
    *and* ``lat_b`` to strictly increase along ``y_vertices``.

    Raises
    ------
    ValueError
        If either condition fails.
    """
    # Test patch; the author's intent was a southern-hemisphere point away from
    # strongly curved grid regions (TODO confirm the indices suit each model).
    p = {"x": slice(20, 22), "y": slice(20, 22)}
    ds_p = ds.isel(**p).squeeze()
    # Get rid of all unnecessary variables: collect them first, then drop once.
    to_drop = [
        var
        for var in ds_p.variables
        if ("lev" in ds_p[var].dims)
        or ("time" in ds_p[var].dims)
        or (var in ["sub_experiment_label", "variant_label"])
    ]
    ds_p = ds_p.drop_vars(to_drop)
    # TODO: it would be nice to reuse the `order` given to
    # `cmip_bounds_to_xesmf` elsewhere instead of hard-coding None here.
    ds_p = cmip_bounds_to_xesmf(ds_p, order=None)
    ds_p = ds_p.load().transpose(..., "x", "y", "vertex")
    lon_increasing = bool((ds_p.lon_b.diff("x_vertices") > 0).all())
    lat_increasing = bool((ds_p.lat_b.diff("y_vertices") > 0).all())
    # Both must hold. The previous `not A and B` only raised when lon failed
    # *and* lat passed, so a non-monotonic lat_b went undetected.
    if not (lon_increasing and lat_increasing):
        raise ValueError("Test vertices not strictly monotinically increasing")

def reorder_vertex(ds, new_order):
    """Return ``ds`` with the ``vertex`` dimension permuted by ``new_order``.

    ``new_order[i]`` is the index of the original vertex that ends up at
    position ``i``. Every variable with a ``vertex`` dimension is permuted;
    variables without it are left untouched.

    Raises
    ------
    ValueError
        If ``ds`` has no ``vertex`` dimension (raised by ``isel``).
    """
    # Indexing with a list permutes along `vertex` in a single step. Unlike
    # splitting the dataset, concatenating per-vertex slices and merging back,
    # this keeps each variable's original dimension order (concat moved
    # `vertex` to the front) and leaves all other variables as they were.
    return ds.isel(vertex=list(new_order))

def get_order(ds):
    """Find the first permutation of the four vertex indices that passes the test.

    Every ordering of ``(0, 1, 2, 3)`` is applied with ``reorder_vertex`` and
    checked with ``test_vertex_order``. The first ordering for which the check
    does not raise is printed and returned.

    Returns
    -------
    tuple[int, ...] | None
        The working ordering, or ``None`` if no permutation passes.
    """
    for new_order in itertools.permutations(range(4)):
        ds_reordered = reorder_vertex(ds, new_order)
        try:
            test_vertex_order(ds_reordered)
        except Exception:
            # Deliberately best-effort: any failure means "try the next
            # ordering". Unlike a bare `except:`, this no longer swallows
            # KeyboardInterrupt/SystemExit.
            continue
        print(f"{new_order=} worked!")
        return new_order
    return None

from xmip.utils import cmip6_dataset_id
import warnings
def test_and_reorder_vertex(ds):
    """Brute-force a vertex order for ``ds`` that yields monotonic lon_b/lat_b.

    This is an expensive check: ``get_order`` tries every possible order of
    the vertex dimension and confirms that strictly monotonic lon_b/lat_b
    coordinates are obtained for a test point.

    Returns
    -------
    The input dataset, depending on the outcome:

    * unchanged, if the existing order already works;
    * with its vertex variables reordered, if another order works;
    * with all vertex variables dropped (and a warning issued), if no order
      works.
    """
    new_order = get_order(ds)
    if new_order is None:
        # Drop them; maybe another approach works better downstream.
        ds_out = ds.drop_vars(
            [va for va in ds.variables if "vertex" in ds[va].dims]
        )
        # Warn rather than print: data is being discarded here.
        warnings.warn(
            f"Unable to find a vertex order for {cmip6_dataset_id(ds)}",
            stacklevel=2,
        )
        return ds_out
    if tuple(new_order) == tuple(range(len(new_order))):
        # The identity order already passes; nothing to change.
        return ds
    print(f"Changing vertex order for {cmip6_dataset_id(ds)}")
    return reorder_vertex(ds, new_order)

This is surely not the most elegant solution, but I'll try to work on it soon.