corey-lambda opened 2 weeks ago
This also needs some updates to saving checkpoints:
```diff
  if state["global_step"] % args.ckpt_freq == 0:
+     optimizer.consolidate_state_dict(to=0)
      if rank == 0:
          torch.save(optimizer.state_dict(), exp_dir / "optimizer.pt")
```
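For completeness, the matching resume path could look roughly like this (a sketch: `exp_dir` comes from the snippet above, and it assumes the consolidated state dict is loaded on every rank):

```python
# Sketch: resuming from the consolidated optimizer checkpoint saved above.
# Every rank loads the full state dict; ZeroRedundancyOptimizer.load_state_dict
# then keeps only the shard owned by the local rank.
ckpt = torch.load(exp_dir / "optimizer.pt", map_location="cpu")
optimizer.load_state_dict(ckpt)
```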
However, HUGE CAVEAT:
consolidate_state_dict transfers the sharded state between a single pair of GPUs at a time, so it is VERY slow with Llama 8B (it takes minutes per GPU).
Not sure it should be recommended for this reason.
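If the consolidation cost is a blocker, one possible workaround (just a sketch, not something the docs recommend) is to skip consolidation and save each rank's local shard directly. This relies on the wrapped local optimizer being exposed as the internal `optimizer.optim` attribute, and resuming then needs the same world size, with each rank loading its shard back into `optimizer.optim`:

```python
# Sketch: per-rank checkpoints that avoid consolidate_state_dict entirely.
# NOTE: optimizer.optim (the wrapped local optimizer) is an internal attribute
# of ZeroRedundancyOptimizer, so this depends on implementation details.
if state["global_step"] % args.ckpt_freq == 0:
    torch.save(
        optimizer.optim.state_dict(),          # this rank's shard only
        exp_dir / f"optimizer_rank{rank}.pt",  # one file per rank
    )
```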
Docs: https://pytorch.org/docs/2.4/distributed.optim.html#torch.distributed.optim.ZeroRedundancyOptimizer
Otherwise, ZeroRedundancyOptimizer is very easy to use and immediately reduces memory usage.
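For reference, the wrapping itself is just (a minimal sketch; the DDP-wrapped `model`, AdamW, and the lr are placeholders):

```python
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

# Shard optimizer state across the DDP ranks instead of replicating it on each one.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.AdamW,  # placeholder inner optimizer
    lr=3e-4,                            # placeholder hyperparameters
)
```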