openclimatefix / graph_weather

PyTorch implementation of Ryan Keisler's 2022 "Forecasting Global Weather with Graph Neural Networks" paper (https://arxiv.org/abs/2202.07575)
MIT License

Add TPU support #27

Open: jacobbieker opened this issue 2 years ago

jacobbieker commented 2 years ago

Detailed Description

Currently, the graph neural network library dependencies (PyTorch Geometric) don't support TPUs, or at least don't seem to, because they rely on custom kernels. Could we add a JAX version for TPU support? The original model was apparently implemented in JAX.
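If the blocker really is the custom scatter kernels, one option is to express message passing using only ops that core PyTorch already provides, such as `Tensor.index_add`, which PyTorch/XLA should be able to lower for TPUs. The sketch below is a hypothetical single message-passing layer (`SimpleMessagePassing` is an illustrative name, not part of graph_weather or PyTorch Geometric). Whether every op lowers efficiently on a TPU is an assumption to verify.

```python
import torch
from torch import nn


class SimpleMessagePassing(nn.Module):
    """Hypothetical message-passing layer using only core PyTorch ops (no torch_scatter)."""

    def __init__(self, node_dim: int, hidden_dim: int):
        super().__init__()
        # Edge MLP: builds one message from (sender features, receiver features)
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * node_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, node_dim)
        )
        # Node MLP: updates each node from (its features, summed incoming messages)
        self.node_mlp = nn.Sequential(
            nn.Linear(2 * node_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, node_dim)
        )

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        # edge_index has shape [2, num_edges]: row 0 = senders, row 1 = receivers
        senders, receivers = edge_index
        messages = self.edge_mlp(torch.cat([x[senders], x[receivers]], dim=-1))
        # Scatter-sum the messages onto their receiving nodes with a built-in op
        aggregated = torch.zeros_like(x).index_add(0, receivers, messages)
        # Residual node update
        return x + self.node_mlp(torch.cat([x, aggregated], dim=-1))
```

For example, `SimpleMessagePassing(8, 16)(torch.randn(4, 8), torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]))` returns updated features of shape `(4, 8)`.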

Context

Being able to use TPUs could speed up training quite a bit.

Possible Implementation

vballoli commented 1 year ago

Hey @jacobbieker, I wanted to start contributing to openclimatefix and take a shot at this if this feature is still of interest.

jacobbieker commented 1 year ago

Hi, yeah that would be awesome!

aavashsubedi commented 3 months ago

Is this still of interest? I'm thinking of working on it over a long time frame.
The bottleneck would be getting the graph block working, but I've managed to get some of the MLP components functioning, using Flax (Linen) + Jraph.
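On the JAX side, the core of a graph block is summing per-edge messages onto their receiving nodes. Jraph's default edge aggregation is a segment sum, and `jax.ops.segment_sum` compiles through XLA without any custom kernels. A minimal pure-JAX sketch follows; the `mlp`, `init_mlp` and `message_passing_step` helpers and their parameter layout are hypothetical, and a real port would more likely wrap this in Flax Linen modules and `jraph.GraphsTuple`.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, in_dim, hidden_dim, out_dim):
    """Random params for a two-layer MLP as a list of (weight, bias) pairs (hypothetical layout)."""
    k1, k2 = jax.random.split(key)
    return [
        (jax.random.normal(k1, (in_dim, hidden_dim)) * 0.1, jnp.zeros(hidden_dim)),
        (jax.random.normal(k2, (hidden_dim, out_dim)) * 0.1, jnp.zeros(out_dim)),
    ]


def mlp(params, x):
    """Apply the two-layer MLP with a SiLU activation."""
    (w1, b1), (w2, b2) = params
    return jax.nn.silu(x @ w1 + b1) @ w2 + b2


@jax.jit
def message_passing_step(edge_params, node_params, nodes, senders, receivers):
    # One message per edge, built from sender and receiver node features
    messages = mlp(edge_params, jnp.concatenate([nodes[senders], nodes[receivers]], axis=-1))
    # Sum incoming messages per receiving node; nodes.shape[0] is static under jit
    aggregated = jax.ops.segment_sum(messages, receivers, num_segments=nodes.shape[0])
    # Residual node update
    return nodes + mlp(node_params, jnp.concatenate([nodes, aggregated], axis=-1))
```

Because everything is plain `jax.numpy` plus `segment_sum`, the whole step is traced and compiled by XLA, which is what makes it a candidate for TPUs.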

aavashsubedi commented 3 months ago

@jacobbieker

jacobbieker commented 3 months ago

Yeah, this is still of interest! It would be good to be able to use this on TPUs. Ideally still in PyTorch, although up for JAX as well.

aavashsubedi commented 3 months ago

Hi, amazing! I wasn't aware until just now that you can use XLA with PyTorch :) For inference, I guess this would be the easiest approach:

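import torch_xla.core.xla_model as xm
import torch_xla.utils.serialization as xser

# `model` is an already-constructed model; 'model.pt' must have been written with xser.save
model.load_state_dict(xser.load('model.pt'))
model.to(xm.xla_device())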

https://stackoverflow.com/questions/69328983/are-pytorch-trained-models-transferable-between-gpus-and-tpus

To keep the mismatch between versions small, I'll stick with PyTorch (maybe there's a neat way to wrap TPU support around each GPU model, we'll see!). Will read up on XLA :)
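Following the idea of wrapping TPU support around the existing GPU models, one hedged sketch is a small device-selection helper that uses PyTorch/XLA when it is installed and otherwise falls back to CUDA or CPU. `get_device` and `load_for_device` are hypothetical names, and this assumes the checkpoint is a plain `state_dict` written with `torch.save` (for example on a GPU), not one written with `xser.save`.

```python
import torch
from torch import nn


def get_device() -> torch.device:
    """Prefer an XLA device (TPU) if PyTorch/XLA is installed, then CUDA, then CPU."""
    try:
        import torch_xla.core.xla_model as xm  # only importable on PyTorch/XLA installs

        return xm.xla_device()
    except ImportError:
        pass
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def load_for_device(model: nn.Module, path: str) -> nn.Module:
    """Load a state_dict saved with torch.save and move the model to the chosen device."""
    # map_location="cpu" lets a GPU-saved checkpoint deserialize on a machine without that GPU
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model.to(get_device())
```

Keeping the model code device-agnostic and deciding the device only at load time means the same GPU model definition can, in principle, be reused unchanged on a TPU host.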