SciTools / cartopy

Cartopy - a cartographic python library with matplotlib support
https://scitools.org.uk/cartopy/docs/latest
BSD 3-Clause "New" or "Revised" License

Speeding up img_transform.regrid #2144

Open nbren12 opened 1 year ago

nbren12 commented 1 year ago

Description

A common use-case for cartopy is to generate a series of maps which are then compiled into a movie (e.g. with ffmpeg).

This is slow. For example, transforming 0.25-degree global data to Mercator takes about 2 seconds per frame on my laptop. Most of that time is spent in a nearest-neighbor search that maps between the target and source projections, but this search only needs to happen once if the source and target grids are fixed.
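To make that cost split concrete, here is a minimal sketch. It is not cartopy code: it assumes the coordinates are already projected onto toy planar grids and uses scipy's `cKDTree` directly. The tree build and query run once, and each frame is then a single fancy-indexing gather.

```python
import numpy as np
from scipy.spatial import cKDTree

# Toy planar grids standing in for coordinates that have already been
# projected (assumption: the CRS transforms are done beforehand).
src_x, src_y = np.meshgrid(np.linspace(0.0, 9.0, 10), np.linspace(0.0, 4.0, 5))
tgt_x, tgt_y = np.meshgrid(np.linspace(0.2, 8.6, 7), np.linspace(0.1, 3.9, 3))

src_points = np.column_stack([src_x.ravel(), src_y.ravel()])
tgt_points = np.column_stack([tgt_x.ravel(), tgt_y.ravel()])

# Expensive step, done once: build the tree and look up, for every target
# point, the flat index of its nearest source point.
_, nearest = cKDTree(src_points).query(tgt_points, k=1)


def regrid_frame(frame):
    """Cheap step, done per frame: one fancy-indexing gather."""
    return frame.ravel()[nearest].reshape(tgt_x.shape)


rng = np.random.default_rng(0)
frames = [rng.random(src_x.shape) for _ in range(3)]
regridded = [regrid_frame(f) for f in frames]
```

Only `regrid_frame` runs inside the frame loop, so the per-frame cost no longer depends on the tree at all.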

I propose refactoring img_transform.regrid into a callable class, Regrid, that runs the expensive search once when it is constructed. It could then be used like this:

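```python
from cartopy import img_transform  # Regrid is the proposed addition

# Expensive setup, done once: takes a couple of seconds
regrid = img_transform.Regrid(xs, ys, src_crs, target_crs, xt, yt)

for i in range(100):
    arr = get_frame(i)
    # almost instantaneous
    regridded = regrid(arr)
    ax.pcolormesh(xt, yt, regridded, transform=target_crs)
```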

xesmf uses the same pattern to cache the expensive weight matrix computation.
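The same build-once, apply-many idea can be phrased as a weight matrix: xesmf's cached weights are, broadly, a sparse matrix applied to each field. The sketch below is not xesmf's API. It is a hypothetical helper that encodes a nearest-neighbour mapping as a `scipy.sparse` matrix, so regridding a frame becomes one matrix-vector product.

```python
import numpy as np
from scipy.sparse import csr_matrix


def nearest_weights(nearest, n_source):
    """Build a sparse (n_target x n_source) weight matrix from precomputed
    nearest-neighbour indices: exactly one weight of 1.0 per row."""
    nearest = np.asarray(nearest)
    n_target = nearest.size
    rows = np.arange(n_target)
    weights = np.ones(n_target)
    return csr_matrix((weights, (rows, nearest)), shape=(n_target, n_source))


# Built once (trivial here, expensive for real grids and weighting schemes).
W = nearest_weights([2, 0, 1, 2], n_source=3)

# Applied per frame: a sparse matrix-vector product on the flattened data.
frame = np.array([10.0, 20.0, 30.0])
out = W @ frame
```

For nearest-neighbour each row holds a single 1.0, so plain indexing is just as good; the matrix form matters for bilinear or conservative schemes, where each target point mixes several source points.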

I went ahead and implemented this (see below), and now I can make movies about 100x faster :) Would you be interested in a PR? Sorry for not following the template too closely; this is a feature request rather than a bug report.

```python
# Proposed addition to cartopy/img_transform.py; it relies on names already
# defined in that module (np, pykdtree/scipy, _is_pykdtree, _determine_bounds).
class Regrid:
    def __init__(self, source_x_coords, source_y_coords, source_proj,
                 target_proj, target_x_points, target_y_points):
        # Stack our original xyz array, this will also wrap coords when necessary
        xyz = source_proj.transform_points(source_proj,
                                           source_x_coords.flatten(),
                                           source_y_coords.flatten())
        # Transform the target points into the source projection
        target_xyz = source_proj.transform_points(target_proj,
                                                  target_x_points.flatten(),
                                                  target_y_points.flatten())

        if _is_pykdtree:
            kdtree = pykdtree.kdtree.KDTree(xyz)
            # Use sqr_dists=True because we don't care about distances,
            # and it saves a sqrt.
            _, indices = kdtree.query(target_xyz, k=1, sqr_dists=True)
        else:
            # Versions of scipy >= v0.16 added the balanced_tree argument,
            # which caused the KDTree to hang with this input.
            kdtree = scipy.spatial.cKDTree(xyz, balanced_tree=False)
            _, indices = kdtree.query(target_xyz, k=1)
        mask = indices >= len(xyz)
        indices[mask] = 0

        back_to_target_xyz = target_proj.transform_points(source_proj,
                                                          target_xyz[:, 0],
                                                          target_xyz[:, 1])
        self.target_shape = target_x_points.shape
        desired_ny, desired_nx = self.target_shape
        back_to_target_x = back_to_target_xyz[:, 0].reshape(desired_ny, desired_nx)
        back_to_target_y = back_to_target_xyz[:, 1].reshape(desired_ny, desired_nx)

        # Do double transform to clip points that do not map back and forth
        # to the same point to within a fixed fractional offset.
        # NOTE: This only needs to be done for (pseudo-)cylindrical projections,
        # or any others which have the concept of wrapping
        FRACTIONAL_OFFSET_THRESHOLD = 0.1  # data has moved by 10% of the map

        x_extent = np.abs(target_proj.x_limits[1] - target_proj.x_limits[0])
        y_extent = np.abs(target_proj.y_limits[1] - target_proj.y_limits[0])

        self.non_self_inverse_points = (
            ((np.abs(target_x_points - back_to_target_x) / x_extent)
             > FRACTIONAL_OFFSET_THRESHOLD)
            | ((np.abs(target_y_points - back_to_target_y) / y_extent)
               > FRACTIONAL_OFFSET_THRESHOLD))

        # Transform the target points to the source projection and mask any points
        # that fall outside the original source domain.
        target_in_source_x = target_xyz[:, 0].reshape(desired_ny, desired_nx)
        target_in_source_y = target_xyz[:, 1].reshape(desired_ny, desired_nx)

        bounds = _determine_bounds(source_x_coords, source_y_coords, source_proj)

        outside_source_domain = ((target_in_source_y >= bounds['y'][1]) |
                                 (target_in_source_y <= bounds['y'][0]))

        tmp_inside = np.zeros_like(outside_source_domain)
        for bound_x in bounds['x']:
            tmp_inside = tmp_inside | ((target_in_source_x <= bound_x[1]) &
                                       (target_in_source_x >= bound_x[0]))
        self.outside_source_domain = outside_source_domain | ~tmp_inside

        self.indices = indices
        self.mask = mask

    def __call__(self, array, mask_extrapolated=False):
        indices = self.indices
        # Squash the first two dims of the source array into one
        temp_array = array.reshape((-1,) + array.shape[2:])

        if np.any(self.mask):
            new_array = np.ma.array(temp_array[indices])
            new_array[self.mask] = np.ma.masked
        else:
            new_array = temp_array[indices]

        new_array.shape = self.target_shape + (array.shape[2:])

        if np.any(self.non_self_inverse_points):
            if not np.ma.isMaskedArray(new_array):
                new_array = np.ma.array(new_array, mask=False)
            new_array[self.non_self_inverse_points] = np.ma.masked

        if mask_extrapolated:
            if np.any(self.outside_source_domain):
                if not np.ma.isMaskedArray(new_array):
                    new_array = np.ma.array(new_array, mask=False)
                new_array[self.outside_source_domain] = np.ma.masked
        return new_array
```

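The "double transform" check in the Regrid code flags target points that do not survive a target → source → target round trip within 10% of the map extent. A 1-D sketch of that test, using hypothetical toy transforms (a target map spanning longitudes 0..360 and a source that wraps into [-180, 180)) rather than real cartopy projections:

```python
import numpy as np

FRACTIONAL_OFFSET_THRESHOLD = 0.1  # same 10% threshold as in Regrid


def non_self_inverse(target_x, to_source, to_target, x_extent):
    """Flag points whose target -> source -> target round trip moves them by
    more than 10% of the map width (a 1-D version of the x-direction check)."""
    back = to_target(to_source(target_x))
    return np.abs(target_x - back) / x_extent > FRACTIONAL_OFFSET_THRESHOLD


# Hypothetical toy transforms: the target map spans longitudes 0..360,
# but the source coordinate system wraps everything into [-180, 180).
def wrap_to_source(lon):
    return (lon + 180.0) % 360.0 - 180.0


def source_to_target(lon):
    return lon  # no unwrapping, so wrapped points do not come back


target_x = np.array([10.0, 90.0, 170.0, 190.0, 270.0, 350.0])
mask = non_self_inverse(target_x, wrap_to_source, source_to_target, 360.0)
```

Points east of 180 come back shifted by a full 360, which is far above the 10% threshold, so they are masked. Because this mask depends only on the grids, Regrid can compute it once in `__init__`.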
greglucas commented 1 year ago

Having some kind of cache for updating image data without needing to transform every time seems like a good improvement! But in your example, you are putting the result into pcolormesh() rather than imshow(), which would be much faster. For pcolormesh(), we already subclass matplotlib's QuadMesh as GeoQuadMesh (https://github.com/SciTools/cartopy/blob/main/lib/cartopy/mpl/geocollection.py) to handle exactly this kind of updating during a loop; see the animation example: https://scitools.org.uk/cartopy/docs/latest/gallery/miscellanea/animate_surface.html#sphx-glr-gallery-miscellanea-animate-surface-py

Doing something similar for AxesImage, i.e. a GeoAxesImage subclass that caches the array mapping on the image object, seems like the way to go here. Your example above would then become:

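```python
# Hypothetical: GeoAxesImage does not exist yet; imshow() would cache the
# source-to-target mapping so that set_array() only re-indexes the data.
image = ax.imshow(get_frame(0), transform=src_crs)
for i in range(100):
    arr = get_frame(i)
    # almost instantaneous
    image.set_array(arr)
```
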
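One way the GeoAxesImage idea might look is sketched below. `CachedRegridImage` is a hypothetical name, and the indices are assumed to be precomputed (for example by something like the proposed Regrid). It subclasses matplotlib's `AxesImage` and overrides `set_data()`, which `set_array()` delegates to, so every update is just a gather with the cached indices.

```python
import numpy as np
from matplotlib.image import AxesImage


class CachedRegridImage(AxesImage):
    """Hypothetical sketch of the GeoAxesImage idea: an AxesImage that stores
    precomputed nearest-neighbour indices and applies them in set_data()."""

    def __init__(self, ax, indices, target_shape, **kwargs):
        # Cache the (expensive to compute) mapping on the artist itself.
        self._indices = np.asarray(indices)
        self._target_shape = tuple(target_shape)
        super().__init__(ax, **kwargs)

    def set_data(self, A):
        A = np.asanyarray(A)
        # Same trick as Regrid.__call__: flatten the two spatial dims,
        # gather with the cached indices, then restore the target shape.
        flat = A.reshape((-1,) + A.shape[2:])
        regridded = flat[self._indices].reshape(self._target_shape + A.shape[2:])
        super().set_data(regridded)
```

A real implementation would also need to handle extents, masking, and the CRS bookkeeping that cartopy's imshow does; this only shows where the cache would live.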
nbren12 commented 1 year ago

I agree imshow makes more sense than QuadMesh (pcolormesh) here. For the movie application, putting the caching in a matplotlib subclass may make sense too.

That said, Regrid is nicely general-purpose and could be useful beyond plotting, for example for batch conversion of data between CRSs. So it could make sense to add both GeoAxesImage and img_transform.Regrid, with GeoAxesImage built on top of the latter. Thoughts?
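For batch conversion, the "squash the first two dims" step in `Regrid.__call__` means any trailing dimensions ride along for free. A minimal sketch with a hypothetical helper (`apply_cached_indices`) and made-up indices standing in for a precomputed mapping: moving time to the last axis converts a whole stack with one gather instead of a Python loop.

```python
import numpy as np


def apply_cached_indices(array, indices, target_shape):
    """Gather an (ny, nx, ...) array onto a target grid using precomputed
    flat nearest-neighbour indices; trailing dims (e.g. time) are kept,
    mirroring how Regrid.__call__ squashes only the first two dims."""
    flat = array.reshape((-1,) + array.shape[2:])
    return flat[indices].reshape(tuple(target_shape) + array.shape[2:])


# Hypothetical batch: 24 time steps on a 3x4 source grid, with time moved to
# the last axis so a single gather converts the whole stack.
stack = np.arange(3 * 4 * 24, dtype=float).reshape(24, 3, 4)
batch = np.moveaxis(stack, 0, -1)                   # shape (3, 4, 24)
indices = np.array([0, 5, 11, 6, 6, 1])             # assumed precomputed
out = apply_cached_indices(batch, indices, (2, 3))  # shape (2, 3, 24)
```

The result matches regridding each time step separately, but the indexing runs once over the whole array.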

greglucas commented 1 year ago

I agree you'll need to do some kind of refactor to cache the needed quantities for re-use, so that may be a good option.