openclimatefix / graph_weather

PyTorch implementation of Ryan Keisler's 2022 "Forecasting Global Weather with Graph Neural Networks" paper (https://arxiv.org/abs/2202.07575)
MIT License
188 stars, 47 forks

More efficient way of encoding the input graph/output graph? #47

Open jacobbieker opened 1 year ago

jacobbieker commented 1 year ago

Currently, one of the issues with this implementation is that when there is a large number of input lat/lon coordinates (such as a 1° grid or finer), the graphs describing the connections between the inputs and the latent graph become huge. The model then has a hard time fitting on a GPU, especially with any batch size larger than 1. It seems like there should be a better way of encoding the inputs into the latent graph than the way I wrote it in this repo, though I'm not sure how yet.
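A rough back-of-envelope estimate shows how quickly this grows. The sketch below assumes, hypothetically, that every input grid point has `edges_per_node` edges into the latent graph, that each edge carries a float32 feature vector of length `edge_dim`, and that the graph is copied once per batch element (as happens when graphs are concatenated into one big disconnected batch graph). These numbers are illustrative, not read from graph_weather, and the estimate ignores the much larger intermediate activations of the edge MLPs.

```python
# Back-of-envelope size of an input-to-latent graph.
# Assumptions (hypothetical): fixed edges per input node, float32 edge
# features of length `edge_dim`, graph replicated once per batch element.


def grid_nodes(resolution_deg: float) -> int:
    """Number of points on a regular global lat/lon grid, poles included."""
    n_lat = int(round(180 / resolution_deg)) + 1
    n_lon = int(round(360 / resolution_deg))
    return n_lat * n_lon


def edge_feature_bytes(n_nodes: int, edges_per_node: int, edge_dim: int,
                       batch_size: int, bytes_per_value: int = 4) -> int:
    """Bytes needed just to store edge features for a replicated graph."""
    n_edges = n_nodes * edges_per_node
    return n_edges * edge_dim * bytes_per_value * batch_size


for res in (1.0, 0.5, 0.25):
    n = grid_nodes(res)
    size = edge_feature_bytes(n, edges_per_node=4, edge_dim=256, batch_size=4)
    print(f"{res:>5} deg: {n:>9,} nodes, ~{size / 2**30:.1f} GiB of edge features")
```

Because the node count roughly quadruples each time the resolution is halved, and the replicated graph scales linearly with batch size, both factors multiply directly into GPU memory.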

assafshouval commented 1 year ago

Hi, I'm thinking of taking this issue. Is the graph you're talking about the Encoder.graph member?

jacobbieker commented 1 year ago

Yes! That's the one. There are now some other implementations of similar models that might be helpful, primarily NVIDIA's GraphCast implementation here: https://github.com/NVIDIA/modulus/blob/main/modulus/models/graphcast/graph_cast_net.py
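One direction in the spirit of GraphCast-style encoders is to keep the grid-to-latent connectivity local and fixed-size, and to build the edge list once and share it across the batch instead of replicating it per sample. The sketch below is a hypothetical illustration of that idea, not code from graph_weather or Modulus. It swaps in a plain k-nearest-neighbour rule (each grid point connects to its `k` nearest latent nodes on the unit sphere) in place of GraphCast's radius and mesh-triangle lookups. It uses a coarse lat/lon grid as a stand-in for the latent graph, and mean aggregation in place of learned message MLPs. All function names are made up for this example.

```python
import numpy as np


def latlon_to_xyz(lat_deg, lon_deg):
    """Convert lat/lon in degrees to 3D unit vectors, shape (N, 3)."""
    lat, lon = np.deg2rad(lat_deg), np.deg2rad(lon_deg)
    return np.stack([np.cos(lat) * np.cos(lon),
                     np.cos(lat) * np.sin(lon),
                     np.sin(lat)], axis=-1)


def knn_edges(grid_xyz, mesh_xyz, k=3, chunk=4096):
    """Hypothetical helper: connect each grid point to its k nearest mesh nodes.

    Returns (src, dst), each of length N_grid * k (src = grid index,
    dst = mesh index). Distances are computed in chunks so the full
    N_grid x N_mesh similarity matrix is never held in memory at once.
    """
    n_grid = grid_xyz.shape[0]
    dst = np.empty((n_grid, k), dtype=np.int64)
    for start in range(0, n_grid, chunk):
        block = grid_xyz[start:start + chunk]
        # On the unit sphere a larger dot product means a smaller distance.
        sim = block @ mesh_xyz.T
        dst[start:start + chunk] = np.argpartition(-sim, k - 1, axis=1)[:, :k]
    src = np.repeat(np.arange(n_grid), k)
    return src, dst.reshape(-1)


def aggregate_to_mesh(x, src, dst, n_mesh):
    """Mean-aggregate batched grid features onto mesh nodes over ONE edge list.

    x has shape (N_grid, B, F); the result has shape (N_mesh, B, F).
    The edge index is stored once, regardless of batch size B.
    """
    out = np.zeros((n_mesh,) + x.shape[1:], dtype=x.dtype)
    np.add.at(out, dst, x[src])  # unbuffered scatter-add along node axis
    counts = np.maximum(np.bincount(dst, minlength=n_mesh), 1).astype(x.dtype)
    return out / counts[:, None, None]


# Toy example: 5-degree "input" grid, 30-degree stand-in for the latent graph.
lat, lon = np.meshgrid(np.arange(-90, 91, 5.0), np.arange(0, 360, 5.0), indexing="ij")
grid_xyz = latlon_to_xyz(lat.ravel(), lon.ravel())
mlat, mlon = np.meshgrid(np.arange(-60, 61, 30.0), np.arange(0, 360, 30.0), indexing="ij")
mesh_xyz = latlon_to_xyz(mlat.ravel(), mlon.ravel())

src, dst = knn_edges(grid_xyz, mesh_xyz, k=3)
x = np.random.default_rng(0).standard_normal((grid_xyz.shape[0], 2, 8)).astype(np.float32)
h = aggregate_to_mesh(x, src, dst, mesh_xyz.shape[0])
print(src.shape, h.shape)
```

With this layout the number of edges is bounded by `N_grid * k` and the index arrays do not grow with batch size. The gathered messages `x[src]` still scale with the batch, so this mainly saves the replicated graph structure and any static edge features rather than all activation memory.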