Closed · ChrisAGBlake closed this 8 months ago
Hi Chris, I'm afraid training a 0.25deg model unrolled to 3 days within TPU RAM limits on TPUv4 requires some further tricks not included in this codebase. The main ones are some additional gradient checkpointing, and offloading some gradient checkpoints from device to host RAM. Our code for this is more tied to internal infrastructure, so the scope of this initial open sourcing was mainly focused on supporting inference, and basic gradient computation for versions of the model that fit more easily into RAM.
Closing as no work imminently planned in this direction in the open source codebase, although the above might give you some hints if you wanted to roll your own.
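The internal code for this isn't public, so the following is only a minimal sketch of the general technique the reply describes: adding extra gradient checkpointing (rematerialisation) with `jax.checkpoint`, so that per-block activations are recomputed during the backward pass instead of being stored. The `block`/`model` functions are hypothetical stand-ins, and the device-to-host offloading of checkpoints mentioned above is not shown here.

```python
# Sketch of additional gradient checkpointing in JAX, assuming a model
# built from repeated blocks. Wrapping each block in jax.checkpoint trades
# extra recomputation in the backward pass for lower activation memory.
import jax
import jax.numpy as jnp


def block(x, w):
    # Hypothetical stand-in for one expensive model block.
    return jnp.tanh(x @ w)


def model(x, ws):
    for w in ws:
        # Rematerialise this block's activations on the backward pass
        # rather than keeping them resident in device memory.
        x = jax.checkpoint(block)(x, w)
    return jnp.sum(x)


ws = [jnp.eye(8) * 0.1 for _ in range(4)]
x = jnp.ones((2, 8))

# Gradients w.r.t. the per-block weights; peak memory no longer grows
# with the number of blocks' stored activations.
grads = jax.grad(model, argnums=1)(x, ws)
```

`jax.checkpoint` also accepts a `policy` argument for finer control over which intermediates are saved versus recomputed, which is where more aggressive memory/compute trade-offs would be configured.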
The paper states that you're using TPU v4 chips, which I believe have 32 GB of memory accessible per chip. When trying to train the high-res version (currently on an NVIDIA GPU) I seem to use more than 48 GB of VRAM for a batch size of 1 per GPU.
I'm using code in the example notebook with minimal modification.
Are there any other things that need to be done to get the memory usage below 32 GB for the high-res model? I note that gradient checkpointing and bfloat16 are already set up in the example notebook in this function, if I'm understanding it correctly.
Do the model inputs (xarray datasets) need to also be cast to a different precision instead of float32?
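For illustration only (I don't know whether casting the inputs actually helps here, and `cast_floats_to_bf16` is a hypothetical helper, not part of the codebase): downcasting just the floating-point arrays of a batch to bfloat16, while leaving integer data untouched, could look like this.

```python
# Hedged sketch: cast floating-point arrays in a pytree (e.g. a dict of
# per-variable arrays) to bfloat16, leaving non-float arrays unchanged.
import jax
import jax.numpy as jnp


def cast_floats_to_bf16(tree):
    def cast(x):
        x = jnp.asarray(x)
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(jnp.bfloat16)
        return x  # leave ints/bools at their original dtype
    return jax.tree_util.tree_map(cast, tree)


batch = {
    "temperature": jnp.ones((4, 4), dtype=jnp.float32),  # float -> bf16
    "level_index": jnp.arange(4),                        # int, untouched
}
casted = cast_floats_to_bf16(batch)
```

Note this only halves the memory of the inputs themselves; if the notebook's bfloat16 wrapper already casts inside the model, the activation memory would be unaffected by casting the inputs up front.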
I'm unsure how to get the memory usage down further so that it could run on a TPU v4 or something like a 40 GB A100.