google-research / dinosaur

Apache License 2.0

Semi-Lagrangian advection #55

Open shoyer opened 3 days ago

shoyer commented 3 days ago

Semi-Lagrangian advection is a key part of efficient spectral atmospheric models.

In ECMWF’s model, it allows for 6x larger time steps compared to standard (Eulerian) advection, corresponding to a 6x faster dynamical core. The dycore is currently the main bottleneck for NeuralGCM, so switching to semi-Lagrangian advection would in principle allow up to 6x faster training and inference for NeuralGCM as well.

Efficiently implementing semi-Lagrangian advection on TPUs presents a challenge, because TPUs do not support efficient irregular indexing operations (GPUs have similar performance challenges, although I believe they are somewhat less severe). Instead, we would need to break indexing into two parts:

  1. Regular indexing (same for all layers) along individual axes, to gather neighbor values. On a TPU, this is likely most efficiently implemented with matrix multiplication with a matrix of ones/zeros. Ideally this matrix should be fixed across as many dimensions as possible.
  2. Applying a fixed-size stencil operation to the collected neighbors. Here we can use arbitrarily varying weights.
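The two steps above can be sketched in JAX roughly as follows. This is only an illustration on a single periodic axis, not code from dinosaur: the helper names `one_hot_gather_matrices` and `gather_then_stencil`, the array shapes, and the choice of offsets are all assumptions made for the example.

```python
import jax.numpy as jnp


def one_hot_gather_matrices(n, offsets):
    """Static 0/1 matrices G[s, i, k], with G = 1 where i == (k + offsets[s]) mod n."""
    k = jnp.arange(n)
    src = (k[None, :] + jnp.asarray(offsets)[:, None]) % n  # [S, n]
    return (jnp.arange(n)[None, :, None] == src[:, None, :]).astype(jnp.float32)


def gather_then_stencil(field, gather, weights):
    """field: [..., n], gather: [S, n, n], weights: [..., S, n] -> [..., n]."""
    # Step 1: regular indexing, written as a matmul with the static 0/1 matrices.
    # The same matrices are reused for every leading (e.g. vertical) index.
    neighbors = jnp.einsum("sik,...i->...sk", gather, field)
    # Step 2: fixed-size stencil; the weights are free to vary at every point.
    return jnp.sum(weights * neighbors, axis=-2)


# Hypothetical usage: 2 layers, 6 points, a 3-point stencil.
field = jnp.arange(12.0).reshape(2, 6)
gather = one_hot_gather_matrices(6, (-1, 0, 1))
# Linear interpolation halfway between each point and its left neighbor.
weights = jnp.stack(
    [jnp.full((2, 6), 0.5), jnp.full((2, 6), 0.5), jnp.zeros((2, 6))], axis=-2
)
interpolated = gather_then_stencil(field, gather, weights)
```

The point of the split is that `gather` is fixed ahead of time and shared across layers, while all of the flow-dependent information lives in `weights`.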

Step (1) is problematic with our current NeuralGCM “full” Gaussian grids, because semi-Lagrangian advection near the poles may require interpolating over many equiangular grid cells in the longitudinal direction, so the set of neighbor cells cannot be determined in a static fashion. Instead, we would need to switch to some form of reduced Gaussian grid, like ECMWF’s octahedral Gaussian grid or HEALPix. These grids all use a reduced number of longitudinal grid points near the poles, which should allow for limiting advection to a static list of neighboring cells:

image

(copied from @milankl's post from https://github.com/SpeedyWeather/SpeedyWeather.jl/issues/112#issuecomment-1219599679)
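To make the pole problem on a full grid concrete: the physical width of a longitude cell shrinks with cos(latitude), so the number of cells crossed in one time step grows as 1/cos(latitude). A minimal sketch, where the wind speed, time step, and grid size are illustrative assumptions rather than NeuralGCM settings:

```python
import jax.numpy as jnp

EARTH_RADIUS_M = 6.371e6


def zonal_cells_crossed(lat_deg, n_lon, wind_m_s, dt_s):
    """Longitude cells traversed in one step on a full (equiangular) grid."""
    dlon = 2 * jnp.pi / n_lon  # cell width in radians, same at every latitude
    cell_width_m = EARTH_RADIUS_M * jnp.cos(jnp.deg2rad(lat_deg)) * dlon
    return wind_m_s * dt_s / cell_width_m


# Hypothetical numbers: 512 longitudes, 50 m/s zonal wind, 1 hour time step.
lats = jnp.array([0.0, 60.0, 85.0, 89.0])
cells = zonal_cells_crossed(lats, n_lon=512, wind_m_s=50.0, dt_s=3600.0)
```

Because the count is unbounded as the latitude approaches 90 degrees, no fixed-size neighbor list covers every row of a full grid.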

HEALPix4 looks especially appealing, due to its relatively simple geometry. For HEALPix4, if we represent our horizontal grids as shown below, we could compute a “padded” representation that would suffice as inputs for stencil operations, merely via indexing across the longitudinal dimension, with the same weights for all longitudes. This would look something like:

image

The implementation of the horizontal indexing operation to create the "padded" representation would be equivalent to a jnp.einsum operation with indices "ijk,...ij->...kj", where the first argument is a static 3D array. We know we can implement operations like this efficiently on GPU/TPU because this is the same structure as the Legendre transforms that we currently implement inside spherical harmonics.
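A toy version of that einsum, as a sketch only. Here `i` is the input longitude, `j` the latitude, and `k` the padded longitude. The table `source_index` and the helpers `padding_operator` and `pad_longitudes` are hypothetical names, and plain periodic wrapping stands in for the real HEALPix4 neighbor geometry, which is not worked out here.

```python
import numpy as np
import jax.numpy as jnp


def padding_operator(source_index, n_lon):
    """Build static P[i, j, k] from source_index[j, k].

    source_index[j, k] is the input longitude that feeds padded column k
    at latitude j, so P has exactly one 1 for every (j, k) pair.
    """
    i = np.arange(n_lon)[:, None, None]
    return jnp.asarray(i == source_index[None, :, :], dtype=jnp.float32)


def pad_longitudes(P, field):
    """field: [..., n_lon, n_lat] -> padded: [..., n_padded, n_lat]."""
    return jnp.einsum("ijk,...ij->...kj", P, field)


# Hypothetical usage: one halo column on each side, identical at every latitude.
n_lon, n_lat, pad = 4, 3, 1
source_index = np.tile((np.arange(n_lon + 2 * pad) - pad) % n_lon, (n_lat, 1))
P = padding_operator(source_index, n_lon)
field = jnp.arange(2 * n_lon * n_lat, dtype=jnp.float32).reshape(2, n_lon, n_lat)
padded = pad_longitudes(P, field)  # shape [2, n_lon + 2 * pad, n_lat]
```

A per-latitude `source_index` is what would let each row pull its halo from different longitudes, while the leading `...` axes (levels, variables) all share the same static `P`.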

Once we have the basic primitives for efficient interpolation, there is still a fair amount of tricky 3D interpolation to work out. Hopefully we could simply implement techniques similar to those currently used in ECMWF's model.

The NeuralGCM team currently has no concrete plans to implement this feature, but it would be quite welcome! This would be a somewhat involved project, so if you're interested in helping out, please do reach out to discuss.

milankl commented 3 days ago

> For HEALPix4, if we represent our horizontal grids as shown below, we could compute a “padded” representation that would suffice as inputs for stencil operations, merely via indexing across the longitudinal dimension, with the same weights for all longitudes. This would look something like:

I now call this grid OctaHEALPixGrid due to its similarity with the octahedral Gaussian grid. Note that you can represent all data on the sphere on this grid as a square matrix

image

and a 3x3 stencil can be continuously defined (not the case for the original HEALPix grid, where one neighbour is missing in the corners that are somewhat equivalent to the corners of the cubed sphere). But boundary conditions are a bit funky at the south pole: cell 64 here has neighbours 61, 62, 63, as if these were just normal biperiodic boundary conditions, but then also, weirdly, 53 and 58 next to the obvious 51, 59, 60. A 3x3 stencil also changes direction across the globe.

shoyer commented 3 days ago

> Note that you can represent all data on the sphere on this grid as a square matrix

On TPU/GPU, funky data access patterns can be quite slow, so I still lean towards the redundant data representation that is still laid out along latitude/longitude dimensions, despite the 100% storage overhead.

We currently use a similar representation for storing the state as spherical harmonic coefficients indexed by (m, l), despite the structural zeros for abs(m) > l, because applying homogeneous operations along different array axes is much faster than packing/unpacking data from a compressed representation.
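For illustration, a dense (m, l) layout with structural zeros might look like the sketch below. The array layout, the truncation, and the helper names `spectral_mask` and `laplacian` are assumptions for the example and are not taken from dinosaur's actual data structures.

```python
import jax.numpy as jnp


def spectral_mask(m, l):
    """True where a coefficient can be nonzero, i.e. abs(m) <= l."""
    return jnp.abs(m)[:, None] <= l[None, :]


def laplacian(coeffs, l, radius=1.0):
    """Spectral Laplacian: scale each total wavenumber l by -l (l + 1) / radius^2.

    This is one broadcasted multiply along the l axis, with no packing or
    unpacking; structural zeros simply stay zero.
    """
    return coeffs * (-l * (l + 1) / radius**2)


# Hypothetical truncation: l = 0..3, m = -3..3, stored as a dense [m, l] array.
m = jnp.arange(-3, 4)
l = jnp.arange(4)
mask = spectral_mask(m, l)
coeffs = jnp.where(mask, 1.0, 0.0)  # placeholder values for the valid entries
lap = laplacian(coeffs, l)
```

In this toy truncation only 16 of the 28 stored entries are meaningful, which is the kind of overhead being traded for uniform array operations.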

Thanks Milan for adding your insights!