LambdaLabsML / distributed-training-guide

Best practices & guides on how to write distributed pytorch training code
MIT License

Add ZeroRedundancyOptimizer to chapters 2 & 3 #44

Open corey-lambda opened 1 month ago

corey-lambda commented 1 month ago

Docs: https://pytorch.org/docs/2.4/distributed.optim.html#torch.distributed.optim.ZeroRedundancyOptimizer

from torch.distributed.optim import ZeroRedundancyOptimizer

# Extra kwargs (lr, fused) are forwarded to the wrapped AdamW constructor.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.AdamW,
    lr=args.lr,
    fused=True,
)

Very easy to use, and it immediately reduces memory usage since each rank only stores the optimizer state for its own shard of the parameters.
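
For context, a minimal sketch of where this slots into a DDP training setup. This is an assumption about how the chapters' scripts are structured, not code from this issue; `model`, `args`, and `dataloader` are placeholders from the surrounding training script, and the loss call assumes an HF-style model.

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes torchrun has launched one process per GPU.
dist.init_process_group("nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

model = DDP(model.to(local_rank), device_ids=[local_rank])

# Each rank only keeps AdamW state for its shard of the parameters;
# lr and fused are forwarded to the wrapped AdamW.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.AdamW,
    lr=args.lr,
    fused=True,
)

for batch in dataloader:
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()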

corey-lambda commented 1 month ago

This also requires a small change to checkpoint saving:

 if state["global_step"] % args.ckpt_freq == 0:
+    optimizer.consolidate_state_dict(to=0)
     if rank == 0:
         torch.save(optimizer.state_dict(), exp_dir / "optimizer.pt")
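
For completeness, a sketch of the matching resume path. This assumes every rank can read exp_dir / "optimizer.pt" (e.g. a shared filesystem); exp_dir comes from the existing script.

# Every rank loads the consolidated checkpoint;
# ZeroRedundancyOptimizer.load_state_dict keeps only that rank's shard.
state = torch.load(exp_dir / "optimizer.pt", map_location="cpu")
optimizer.load_state_dict(state)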

However, HUGE CAVEAT:

consolidate_state_dict transfers the optimizer state between a single pair of GPUs at a time, so it is VERY slow with Llama 8B (it takes minutes per GPU).

Not sure if it should be recommended for this reason.