ORNL / HydraGNN

Distributed PyTorch implementation of multi-headed graph convolutional neural networks
BSD 3-Clause "New" or "Revised" License

Enhancement: model-agnostic layer-wise checkpointing #250

Closed LemonAndRabbit closed 3 months ago

LemonAndRabbit commented 4 months ago

Checkpointing is a widely used technique (e.g., in DeepSpeed) for reducing peak GPU memory consumption. It achieves this by recomputing intermediate activations during the backward pass instead of storing them from the forward pass.

This enhancement leverages PyTorch's official checkpoint implementation and applies it at per-convolution-layer granularity, as sketched below.
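For illustration, here is a minimal sketch of per-layer activation checkpointing built on `torch.utils.checkpoint` (PyTorch's official checkpoint implementation referenced above). The module and attribute names (`CheckpointedConvStack`, the `conv_checkpointing` flag on the module) are assumptions for the example, not HydraGNN's actual internals.

```python
# Minimal sketch of per-convolution-layer activation checkpointing.
# Names here are illustrative only, not HydraGNN's actual code.
import torch
from torch.utils.checkpoint import checkpoint


class CheckpointedConvStack(torch.nn.Module):
    def __init__(self, convs, conv_checkpointing=True):
        super().__init__()
        self.convs = torch.nn.ModuleList(convs)
        self.conv_checkpointing = conv_checkpointing

    def forward(self, x, edge_index):
        for conv in self.convs:
            if self.conv_checkpointing and self.training:
                # Activations inside `conv` are not stored; they are
                # recomputed during backward, trading compute for memory.
                x = checkpoint(conv, x, edge_index, use_reentrant=False)
            else:
                x = conv(x, edge_index)
        return x
```

Wrapping each convolution layer separately keeps the recomputation window small: during backward, only one layer's forward pass is replayed at a time, which is what bounds the peak activation memory.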

Key features include:

Usage: config["NeuralNetwork"]["Training"]["conv_checkpointing"] = True (see the config sketch after this list)

Overhead: with an 18-layer, 512-dimension PNA network, recomputing activations increases training time by ~28.6%, as observed on a server with 8x A5000 GPUs.

Raw profiling results: the memory-consumption and latency profiling results are attached (google_drive). To visualize them, load the memory snapshots at pytorch.org/memory_viz and the latency traces at ui.perfetto.dev/, respectively.
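For completeness, a hedged usage sketch showing where the new flag sits when a HydraGNN-style JSON config is loaded as a Python dict. The config file path and any other keys in the file are placeholders; only the `conv_checkpointing` key is introduced by this enhancement.

```python
import json

# Placeholder path; substitute an actual HydraGNN example config file.
with open("examples/my_config.json") as f:
    config = json.load(f)

# Enable layer-wise activation checkpointing of the convolution layers.
# When this key is absent or False, the forward pass stores activations
# as usual and no recomputation overhead is incurred.
config["NeuralNetwork"]["Training"]["conv_checkpointing"] = True
```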