google-deepmind / graphcast


GPU / TPU memory requirements for training #55

Closed ChrisAGBlake closed 8 months ago

ChrisAGBlake commented 8 months ago

In the paper it states that you're using TPU v4 chips, which I believe have 32 GB of memory accessible per TPU core. When trying to train the high-res version (currently on NVIDIA GPUs), I seem to use more than 48 GB of VRAM for a batch size of 1 per GPU.
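(For anyone wanting to reproduce the numbers: something like the snippet below should report the peak allocation, assuming a recent enough jaxlib where Device.memory_stats() is available on the GPU backend; the exact keys are backend-dependent.)

import jax

# After running one jitted training step, inspect the device allocator stats.
# memory_stats() is exposed by recent jaxlib GPU backends; keys vary by backend.
stats = jax.local_devices()[0].memory_stats()
print(f"peak bytes in use: {stats['peak_bytes_in_use'] / 1e9:.1f} GB")
print(f"bytes limit:       {stats.get('bytes_limit', 0) / 1e9:.1f} GB")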

I'm using code in the example notebook with minimal modification.

Is there anything else that needs to be done to get the memory usage below 32 GB for the high-res model? I note that gradient checkpointing and bfloat16 are already set up in the example notebook in this function, if I'm understanding it correctly:

def construct_wrapped_graphcast(
        model_config: graphcast.ModelConfig,
        task_config: graphcast.TaskConfig):
    # Deeper one-step predictor.
    predictor = graphcast.GraphCast(model_config, task_config)

    # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion
    # from/to float32 to/from BFloat16.
    predictor = casting.Bfloat16Cast(predictor)

    # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
    # BFloat16 happens after applying normalization to the inputs/targets.
    predictor = normalization.InputsAndResiduals(
        predictor,
        diffs_stddev_by_level=diffs_stddev_by_level,
        mean_by_level=mean_by_level,
        stddev_by_level=stddev_by_level)

    # Wraps everything so the one-step model can produce trajectories.
    predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
    return predictor

Do the model inputs (xarray datasets) also need to be cast to a different precision, rather than being left as float32?
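(If explicit casting were ever needed, I'd expect it to look something like the sketch below, though my reading of casting.Bfloat16Cast is that it already converts floating-point inputs to bfloat16 internally and casts the outputs back, so the datasets can presumably stay float32. The cast_float_vars helper and train_inputs name are just placeholders here, not part of the codebase.)

import jax.numpy as jnp
import xarray

def cast_float_vars(ds: xarray.Dataset, dtype=jnp.bfloat16) -> xarray.Dataset:
    # Only touch floating-point data variables; leave ints/bools untouched.
    return ds.map(lambda da: da.astype(dtype) if da.dtype.kind == "f" else da)

# e.g. train_inputs_bf16 = cast_float_vars(train_inputs)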

I'm unsure how to get the memory usage down further so that it could be trained on a TPU v4 or something like a 40 GB A100.

mjwillson commented 8 months ago

Hi Chris, I'm afraid training a 0.25deg model unrolled to 3 days within the RAM limits of a TPU v4 requires some further tricks that are not included in this codebase. The main ones are additional gradient checkpointing, and offloading some gradient checkpoints from device to host RAM. Our code for this is tied to internal infrastructure, so the scope of this initial open-sourcing was mainly focused on supporting inference, plus basic gradient computation for versions of the model that fit more easily into RAM.
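(As a rough, generic JAX illustration of the first trick only, and not the internal training code: "additional gradient checkpointing" amounts to wrapping the per-layer computation in jax.checkpoint, so only layer inputs are kept and everything inside a layer is recomputed during the backward pass. The layer_fn below is a hypothetical stand-in for one block of the deep GNN; offloading those saved checkpoints to host RAM needs extra device-to-host transfer machinery that isn't shown here.)

import jax
import jax.numpy as jnp

def layer_fn(params, x):
    # Hypothetical stand-in for one message-passing / MLP block of the network.
    return jnp.tanh(x @ params["w"] + params["b"])

# Rematerialise everything inside the layer: keep only its input, and recompute
# its intermediates during the backward pass (compute traded for memory).
remat_layer = jax.checkpoint(
    layer_fn, policy=jax.checkpoint_policies.nothing_saveable)

def network(stacked_params, x):
    # Scanning over the layer stack means only the per-layer inputs (the
    # checkpoints) stay resident, not every intermediate activation.
    def body(carry, params):
        return remat_layer(params, carry), None
    out, _ = jax.lax.scan(body, x, stacked_params)
    return out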

mjwillson commented 8 months ago

Closing, as no work is imminently planned in this direction in the open-source codebase, although the above might give you some hints if you want to roll your own.
