mlcommons / algorithmic-efficiency

MLCommons Algorithmic Efficiency is a benchmark and competition measuring neural network training speedups due to algorithmic improvements in both training algorithms and models.
https://mlcommons.org/en/groups/research-algorithms/
Apache License 2.0

Support FSDP in PyTorch #796

Open priyakasimbeg opened 3 hours ago

priyakasimbeg commented 3 hours ago

Sharding optimizer state across devices saves significant memory and reflects current practice, so we want to support it.
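To put a rough number on the memory saving, here is a back-of-the-envelope sketch. It assumes an Adam-style optimizer that keeps two fp32 moment estimates per parameter. The helper name `adam_state_bytes_per_device` and the 1B-parameter model are hypothetical, chosen only for illustration, and the estimate ignores parameters, gradients, and activations.

```python
def adam_state_bytes_per_device(num_params, num_devices=1, bytes_per_value=4):
    """Approximate Adam optimizer-state memory held by each device.

    Hypothetical helper for illustration; not part of the benchmark code.
    """
    # Adam stores two moment estimates (first and second) per parameter.
    total_bytes = 2 * num_params * bytes_per_value
    # Sharding splits the state evenly; round up to cover the remainder shard.
    return -(-total_bytes // num_devices)


num_params = 1_000_000_000  # hypothetical 1B-parameter model
replicated = adam_state_bytes_per_device(num_params)
sharded = adam_state_bytes_per_device(num_params, num_devices=8)
print(f"replicated: {replicated / 1e9:.1f} GB per device")
print(f"sharded over 8 devices: {sharded / 1e9:.1f} GB per device")
```

Under these assumptions, the per-device optimizer state shrinks in proportion to the number of devices it is sharded over.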

priyakasimbeg commented 3 hours ago

From the meeting minutes (Michael Shi): the challenge is ensuring that JAX and PyTorch are equivalent. The PyTorch side should be doable by replacing the DDP wrapper with the FSDP wrapper.
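A minimal sketch of what the wrapper swap could look like, assuming the process group is already initialized (e.g. via `torch.distributed.init_process_group`) and one GPU per rank. The helper `wrap_model` and its `strategy` argument are hypothetical and do not reflect how this repository is actually structured. The wrapper classes are PyTorch's `DistributedDataParallel` and `FullyShardedDataParallel`.

```python
def wrap_model(model, local_rank, strategy="ddp"):
    """Wrap `model` for multi-device training (hypothetical helper)."""
    if strategy not in ("ddp", "fsdp"):
        raise ValueError(f"unknown strategy: {strategy!r}")

    # Imported lazily so this sketch can be loaded without PyTorch installed.
    if strategy == "fsdp":
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

        # Shards parameters and gradients across ranks. An optimizer built
        # on the wrapped model then holds state only for its local shard.
        return FSDP(model, device_id=local_rank)

    from torch.nn.parallel import DistributedDataParallel as DDP

    # Replicates the full model (and thus full optimizer state) on every rank.
    return DDP(model, device_ids=[local_rank])
```

With FSDP, the optimizer should be constructed after wrapping, so that it references the sharded parameters rather than the original full ones. Checkpointing and any code that reads full parameter tensors would likely need separate changes, which this sketch does not cover.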