jlamypoirier opened 3 weeks ago
Distributed checkpoint files are just a single big tensor, so the safetensors and torch formats are overkill. We can get away with a pure binary format like numpy's, which has near-optimal write speed and can be converted to/from torch without a copy. I got 987 MiB/s with this, including the GPU-to-CPU transfer, which is still way faster than safetensors. We could do better by buffering to overlap the transfer and the write, but it doesn't matter much because we'll want async checkpoints anyway.
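A minimal sketch of what raw binary shard I/O could look like, assuming a flat 1-D shard tensor (the function names are hypothetical, and the byte-level view through `uint8` is just one way to handle dtypes like `bfloat16` that have no numpy equivalent):

```python
import numpy as np
import torch


def save_shard(shard: torch.Tensor, path: str) -> None:
    # Copy to CPU, reinterpret the buffer as raw bytes, and write it directly to disk.
    # Viewing as uint8 sidesteps dtypes numpy doesn't know about (e.g. bfloat16).
    shard.cpu().contiguous().view(torch.uint8).numpy().tofile(path)


def load_shard(path: str, dtype: torch.dtype) -> torch.Tensor:
    # Read the raw bytes back and reinterpret them as the original dtype.
    # torch.from_numpy shares memory with the numpy buffer, so there is no extra copy.
    return torch.from_numpy(np.fromfile(path, dtype=np.uint8)).view(dtype)
```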
So I propose a new format for distributed checkpoints in v0.2: each shard saved as a raw binary file, with `metadata.yaml` used exclusively for checkpoint metadata. This means dropping backward compatibility, but we were already planning to do that. Going forward, I think we should always refuse to load distributed checkpoints written by a different version, and enforce the `state_dict` format for long-term storage.
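For illustration, a sketch of how the metadata file and the version check could work; the field names, function names, and `FORMAT_VERSION` constant are assumptions, not the actual Fast-LLM schema:

```python
import yaml

FORMAT_VERSION = "0.2"  # hypothetical tag identifying the new checkpoint format


def write_metadata(path: str, dtype: str, shard_numel: int, world_size: int) -> None:
    # All checkpoint metadata lives here; the shard files themselves are raw bytes only.
    metadata = {
        "format_version": FORMAT_VERSION,
        "dtype": dtype,
        "shard_numel": shard_numel,
        "world_size": world_size,
    }
    with open(path, "w") as f:
        yaml.safe_dump(metadata, f)


def read_metadata(path: str) -> dict:
    # Refuse to load distributed checkpoints written by a different version.
    with open(path) as f:
        metadata = yaml.safe_load(f)
    if metadata["format_version"] != FORMAT_VERSION:
        raise ValueError(f"Incompatible checkpoint format version: {metadata['format_version']}")
    return metadata
```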
I'd rather address this in 0.3 and put #25 in 0.2, mostly because #26 affects the internal checkpoint format, whereas #25 is user-facing and improves usability significantly. Let's discuss next week.
Re: speed. Loading can be sped up too: https://x.com/_philschmid/status/1853339965612060797
Describe the Bug
I noticed checkpoint saving is suspiciously slow in some tests, so I decided to investigate.
Checkpoint saving should be bottlenecked by hardware (disk write speed), but it turns out it's not. There seems to be something in `torch.save` / safetensors that makes serialization slow.
On an H100 node with NVMe storage, we should have more than 2 GiB/s of write speed, but `torch.save` and `safetensors.torch.save_file` give less than a third of that. I think the impact isn't too big for distributed checkpoints because multiple processes are saving at the same time, but it definitely makes exports and conversions slower than they should be.
This problem is hard to solve because it lives in other libraries, but maybe we can find some kind of workaround to speed things up, like splitting the serialization across multiple workers.
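One possible shape for such a workaround, purely speculative: partition the state dict across a few workers and write each group to its own file concurrently. The function name, file naming, and worker count below are assumptions; whether this actually helps depends on how much of the serialization path releases the GIL (a process pool would avoid that question at the cost of extra copies).

```python
from concurrent.futures import ThreadPoolExecutor

import safetensors.torch
import torch


def save_parallel(state_dict: dict[str, torch.Tensor], prefix: str, num_workers: int = 4) -> None:
    # Partition the tensors into roughly equal groups, one output file per group.
    names = list(state_dict)
    groups = [names[i::num_workers] for i in range(num_workers)]

    def _save(index: int) -> None:
        shard = {name: state_dict[name] for name in groups[index]}
        safetensors.torch.save_file(shard, f"{prefix}_{index}.safetensors")

    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # list() forces evaluation so any exception raised in a worker surfaces here.
        list(pool.map(_save, range(num_workers)))
```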
Steps to Reproduce
Third-party library benchmarks
Running this code:
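(The snippet below is a minimal sketch of such a benchmark rather than the original script; the tensor size, dtype, local path, commented-out remote path, and the assumption that a GPU is available are all illustrative.)

```python
import time

import safetensors.torch
import torch

path = "/tmp/checkpoint_benchmark"  # local NVMe
# path = "/mnt/remote/checkpoint_benchmark"  # remote storage, for comparison

# One large tensor, roughly the size of a checkpoint shard (2 GiB of float16).
tensor = torch.randn(2**30, dtype=torch.float16, device="cuda")
size_mib = tensor.numel() * tensor.element_size() / 2**20


def benchmark(name, save_fn):
    # Time a single save, including the GPU-to-CPU copy done inside save_fn.
    torch.cuda.synchronize()
    start = time.perf_counter()
    save_fn()
    elapsed = time.perf_counter() - start
    print(f"{name}: {elapsed:.2f} s, {size_mib / elapsed:.0f} MiB/s")


benchmark("torch.save", lambda: torch.save(tensor.cpu(), f"{path}.pt"))
benchmark("safetensors", lambda: safetensors.torch.save_file({"tensor": tensor.cpu()}, f"{path}.safetensors"))
# Raw binary write through numpy, still including the GPU-to-CPU copy.
benchmark("numpy", lambda: tensor.cpu().numpy().tofile(f"{path}.bin"))
```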
We get results in line with the numbers above: the raw binary write approaches 1 GiB/s even including the GPU-to-CPU copy, while `torch.save` and `safetensors.torch.save_file` stay well below that.
I also tried with remote storage (the commented-out path); it's a bit slower (~1.5 GiB/s) but follows the same pattern.
Fast-LLM benchmarks
Running the Mistral-7B 4-node benchmark on H100 nodes, with an added checkpoint and export:
The checkpoint is saved in 7.76 s and each shard is 2.53 GiB, so the write speed is 334 MiB/s per process, or about 2670 MiB/s per node (8 processes per node). This seems reasonable, but I don't know what speed we should be able to get in theory.
The first export file is saved in 13.60 s and is 8.01 GiB, so the write speed is 603 MiB/s (the second file reaches 618 MiB/s). This is at least 4x slower than it should be.
Environment Information
DGX-H100, saving locally (NVMe).