LambdaLabsML / distributed-training-guide

Best practices & guides on how to write distributed pytorch training code

Add ZeroRedundancyOptimizer to chapters 2 & 3 #44

Open corey-lambda opened 2 weeks ago

corey-lambda commented 2 weeks ago

Docs: https://pytorch.org/docs/2.4/distributed.optim.html#torch.distributed.optim.ZeroRedundancyOptimizer

-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
+    optimizer = ZeroRedundancyOptimizer(
+        model.parameters(),
+        optimizer_class=torch.optim.AdamW,
+        lr=args.lr,
+        fused=True
+    )

Very easy to use, and it immediately reduces per-GPU memory usage, since each rank stores only its shard of the optimizer state instead of a full replica.
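
For context, here is a minimal sketch of how this could slot into a DDP training script (the tiny model, dummy batch, and hyperparameters below are stand-ins, not the guide's actual code):

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel

# Launch with e.g.: torchrun --nproc_per_node=8 train.py
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Stand-in model; the guide's chapters train a real LLM here.
model = DistributedDataParallel(
    torch.nn.Linear(4096, 4096).cuda(), device_ids=[local_rank]
)

# Each rank stores only its shard of the AdamW state (exp_avg, exp_avg_sq)
# instead of every rank holding a full replica.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.AdamW,
    lr=1e-4,
    fused=True,
)

for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4096, device="cuda")).sum()
    loss.backward()
    # Each rank steps its own shard, then the updated parameters are
    # broadcast so all ranks stay in sync.
    optimizer.step()

dist.destroy_process_group()
```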

corey-lambda commented 2 weeks ago

This also requires a small update to checkpoint saving:

 if state["global_step"] % args.ckpt_freq == 0:
+    optimizer.consolidate_state_dict(to=0)
     if rank == 0:
         torch.save(optimizer.state_dict(), exp_dir / "optimizer.pt")

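To spell that out, here is a sketch of the save/load pair under these assumptions (the helper names are made up; `exp_dir` and `rank` are the same variables as in the diff above):

```python
import torch
import torch.distributed as dist


def save_optimizer(optimizer, exp_dir, rank):
    # Collective call: every rank must participate, not just rank 0.
    # It gathers each rank's shard of the optimizer state onto rank 0.
    optimizer.consolidate_state_dict(to=0)
    if rank == 0:
        torch.save(optimizer.state_dict(), exp_dir / "optimizer.pt")
    dist.barrier()  # keep other ranks from running ahead while rank 0 writes


def load_optimizer(optimizer, exp_dir):
    # Every rank loads the full (consolidated) state dict; each rank keeps
    # only the entries belonging to its own shard of the parameters.
    full_state = torch.load(exp_dir / "optimizer.pt", map_location="cpu")
    optimizer.load_state_dict(full_state)
```
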
However, HUGE CAVEAT:

The consolidate_state_dict call transfers state between a single pair of GPUs at a time, so it is VERY slow with Llama 8B (taking minutes per GPU).

Not sure if it should be recommended for this reason.
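
If it helps quantify that, something like this (purely illustrative, not part of the guide) could be used to time the consolidation on a given setup:

```python
import time

import torch.distributed as dist


def time_consolidation(optimizer):
    # Rough timing of the cross-rank state-dict gather; `optimizer` is the
    # ZeroRedundancyOptimizer from above.
    dist.barrier()
    start = time.perf_counter()
    optimizer.consolidate_state_dict(to=0)
    dist.barrier()
    if dist.get_rank() == 0:
        print(f"consolidate_state_dict took {time.perf_counter() - start:.1f}s")
```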