openclimatefix / graph_weather

PyTorch implementation of Ryan Keisler's 2022 "Forecasting Global Weather with Graph Neural Networks" paper (https://arxiv.org/abs/2202.07575)

GenCast's Processor #119

Closed gbruno16 closed 4 months ago

gbruno16 commented 4 months ago

Pull Request

Description

From the paper:

The Processor is a graph transformer model operating on a spherical mesh that computes neighbourhood-based self-attention. Unlike the multimesh used in GraphCast, the mesh in GenCast is a 6-times refined icosahedral mesh as defined in Lam et al. (2023), with 41,162 nodes and 246,960 edges. The Processor consists of 16 consecutive standard transformer blocks (Nguyen and Salazar, 2019; Vaswani et al., 2017), with a feature dimension equal to 512. The 4-head self-attention mechanism in each block is such that each node in the mesh attends to itself and to all other nodes within its 32-hop neighbourhood on the mesh.
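For reference, the hyper-parameters quoted above collected into a single config object. This is a sketch only; the dataclass and field names are illustrative and not part of this PR:

```python
from dataclasses import dataclass


@dataclass
class ProcessorConfig:
    mesh_refinements: int = 6  # 6-times refined icosahedral mesh
    num_blocks: int = 16       # consecutive transformer blocks
    hidden_dim: int = 512      # feature dimension
    num_heads: int = 4         # self-attention heads per block
    khop: int = 32             # each node attends within its 32-hop neighbourhood
```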

Transformers

In this PR, there are two different versions of the transformer blocks:

Conditional Layer Normalization

Every LayerNorm layer is replaced by a custom module: an element-wise affine transformation is applied to the output of the LayerNorm, with the scale and shift parameters computed by linear layers from Fourier embeddings of the noise level.
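A minimal sketch of this conditioning scheme, assuming the noise level arrives as a per-sample scalar; the module and argument names below are illustrative, not the PR's actual API:

```python
import math
import torch
import torch.nn as nn


class NoiseFourierEmbedding(nn.Module):
    """Embeds a scalar noise level with fixed sine/cosine Fourier features."""

    def __init__(self, dim: int, max_freq: float = 1000.0):
        super().__init__()
        half = dim // 2
        # Exponentially spaced frequencies from 1 to max_freq.
        freqs = torch.exp(torch.linspace(0.0, math.log(max_freq), half))
        self.register_buffer("freqs", freqs)

    def forward(self, noise_level: torch.Tensor) -> torch.Tensor:
        # noise_level: (batch,) -> (batch, dim)
        angles = noise_level[:, None] * self.freqs[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)


class ConditionalLayerNorm(nn.Module):
    """LayerNorm whose element-wise scale and shift come from the noise embedding."""

    def __init__(self, feature_dim: int, noise_emb_dim: int):
        super().__init__()
        # Plain LayerNorm without its own affine parameters.
        self.norm = nn.LayerNorm(feature_dim, elementwise_affine=False)
        # Linear maps from the noise embedding to per-feature scale and shift.
        self.to_scale = nn.Linear(noise_emb_dim, feature_dim)
        self.to_shift = nn.Linear(noise_emb_dim, feature_dim)

    def forward(self, x: torch.Tensor, noise_emb: torch.Tensor) -> torch.Tensor:
        # x: (batch, nodes, feature_dim), noise_emb: (batch, noise_emb_dim)
        scale = self.to_scale(noise_emb)[:, None, :]
        shift = self.to_shift(noise_emb)[:, None, :]
        # The (1 + scale) form is a common choice so the layer starts near identity.
        return self.norm(x) * (1 + scale) + shift
```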

K-hop Neighbours

The k-hop mesh graph is now computed with repeated sparse multiplications of the adjacency matrix instead of relying on the PyG implementation.
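A minimal sketch of that idea, assuming an undirected `edge_index` in COO format without self loops; the function name and signature below are illustrative, not the PR's actual code:

```python
import torch


def khop_edge_index(edge_index: torch.Tensor, num_nodes: int, k: int) -> torch.Tensor:
    """Return the edge_index of the k-hop graph (self loops included)."""
    values = torch.ones(edge_index.shape[1])
    adj = torch.sparse_coo_tensor(edge_index, values, (num_nodes, num_nodes)).coalesce()

    # Add self loops so reachability accumulates over <= k hops.
    eye = torch.sparse_coo_tensor(
        torch.arange(num_nodes).repeat(2, 1),
        torch.ones(num_nodes),
        (num_nodes, num_nodes),
    )
    reach = (adj + eye).coalesce()

    # Repeated sparse matmul: the nonzeros of reach^k are all pairs within k hops.
    out = reach
    for _ in range(k - 1):
        out = torch.sparse.mm(out, reach).coalesce()
        # Keep only the sparsity pattern; the values would otherwise grow quickly.
        out = torch.sparse_coo_tensor(
            out.indices(), torch.ones_like(out.values()), out.shape
        ).coalesce()

    return out.indices()
```

For a hop count like 32, repeatedly squaring the reachability matrix would need only log2(k) sparse products instead of k - 1, at the cost of a slightly less direct loop.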