ServiceNow / Fast-LLM

Accelerating your LLM training to full speed
https://servicenow.github.io/Fast-LLM/

Speed up checkpoint serialization #26

Open jlamypoirier opened 3 weeks ago

jlamypoirier commented 3 weeks ago

🐞 Describe the Bug

I noticed checkpoint saving is suspiciously slow in some tests, so I decided to investigate.

Checkpoint saving should be bottlenecked by hardware (disk write speed), but it turns out it isn't: something in torch.save / safetensors makes the serialization itself slow.

On an H100 node with local NVMe we should get more than 2 GiB/s of write speed, but torch.save and safetensors.torch.save_file deliver less than a third of that.

I think the impact isn't too big for distributed checkpoints because multiple processes are saving at the same time, but it definitely makes exports and conversions slower than they should be.

This problem is hard to solve because it's in other libraries, but maybe we can find some kind of hack to speed things up, like using multiple serialization processes.
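For illustration, here is a rough sketch of the kind of workaround I mean, assuming the state dict is already on CPU. The names (`parallel_save`, `_save_shard`), the round-robin sharding, and the worker count are made up, the fork-based approach is Linux-only, and whether it actually pays off would have to be measured:

```python
# Hypothetical sketch: serialize a CPU state dict as several shards in parallel
# worker processes. Uses the "fork" start method (Linux) so the workers see the
# parent's tensors copy-on-write instead of pickling them.
import multiprocessing as mp
import pathlib

import safetensors.torch
import torch

_STATE: dict[str, torch.Tensor] = {}  # set before forking, inherited by the workers


def _save_shard(keys: list[str], path: pathlib.Path) -> None:
    # Each worker serializes only its subset of the state dict.
    safetensors.torch.save_file({k: _STATE[k] for k in keys}, str(path))


def parallel_save(
    state: dict[str, torch.Tensor], out_dir: pathlib.Path, num_workers: int = 4
) -> None:
    global _STATE
    _STATE = state  # tensors must already be on CPU
    out_dir.mkdir(parents=True, exist_ok=True)
    keys = list(state)
    shards = [keys[i::num_workers] for i in range(num_workers)]  # round-robin split
    ctx = mp.get_context("fork")
    procs = [
        ctx.Process(target=_save_shard, args=(shard, out_dir / f"shard_{i}.safetensors"))
        for i, shard in enumerate(shards)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```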

πŸ”„ Steps to Reproduce

Third-party library benchmarks

Running this code:

import contextlib
import pathlib
import shutil
import time

import numpy as np
import safetensors.torch
import torch

size = 2**30  # 1 GiB of uint8 data
a = torch.ones(size, dtype=torch.uint8, device="cuda")
b = a.cpu()
c = b.numpy()

dir = pathlib.Path("tmp")
# dir = pathlib.Path("/mnt/workspace/tmp/ckpt")  # remote storage

if dir.is_dir():
    shutil.rmtree(dir)
dir.mkdir(exist_ok=True)

@contextlib.contextmanager
def measure_time(name):
    start = time.time()
    yield
    stop = time.time()
    print(f"{name}: {stop-start:.3f}s, {size/2**20/(stop-start):.3f} MiB/s")

with measure_time("torch save from gpu"):
    torch.save(a, (dir/"torch.pt").open("wb"))

with measure_time("torch save from cpu"):
    torch.save(b, (dir/"torch1.pt").open("wb"))

with measure_time("numpy save"):
    np.save((dir/"np").open("wb"), c)

with measure_time("safetensors save from gpu"):
    safetensors.torch.save_file({"a":a},dir/"c.safetensors")

with measure_time("safetensors save from cpu"):
    safetensors.torch.save_file({"a":a},dir/"c1.safetensors")

with measure_time("safetensors serialize in-memory from cpu"):
    d=safetensors.torch.save({"a":b})

with measure_time("safetensors serialize in-memory from gpu"):
    d=safetensors.torch.save({"a":a})

with measure_time("Write to disk"):
    with (dir / "d.safetensors").open("wb") as f:
        f.write(d)

We get:

torch save from gpu: 1.338s, 765.159 MiB/s
torch save from cpu: 1.641s, 623.839 MiB/s
numpy save: 0.536s, 1909.272 MiB/s
safetensors save from gpu: 1.558s, 657.230 MiB/s
safetensors save from cpu: 2.307s, 443.877 MiB/s
safetensors serialize in-memory from cpu: 1.658s, 617.791 MiB/s
safetensors serialize in-memory from gpu: 2.092s, 489.483 MiB/s
Write to disk: 0.497s, 2062.189 MiB/s

I also tried with remote storage (the commented-out path); it's a bit slower (~1.5 GiB/s) but follows the same pattern.

Fast-LLM benchmarks

Running the Mistral-7B 4-node benchmark on H100 nodes, with an added checkpoint and export:

2024-10-25 16:14:10,223 [Rank 00] Saving checkpoint at iteration 100
2024-10-25 16:14:17,987 [Rank 00] Saved checkpoint to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/checkpoints/100
2024-10-25 16:14:17,989 [Rank 00] Saving export at iteration 100
2024-10-25 16:14:17,999 [Rank 00] Saving tensors to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/export/100/state_dict_0.safetensors
2024-10-25 16:14:31,599 [Rank 00] Saving tensors to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/export/100/state_dict_1.safetensors
2024-10-25 16:14:40,670 [Rank 00] Saving index to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/export/100/state_dict.safetensors.index.json
2024-10-25 16:14:40,679 [Rank 00] Saved export to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/export/100

The checkpoint is saved in 7.76 s and each shard is 2.53 GiB, so the write speed is 334 MiB/s per process, or 2670 MiB/s per node. This seems reasonable, but I don't know what speed we should be able to get in theory.

The first export file is saved in 13.60 s and is 8.01 GiB, so the write speed is 603 MiB/s (the second file gets 618 MiB/s). This is at least 4x too slow.
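For reference, the arithmetic behind those throughput numbers (sizes and times taken from the log above; `mib_per_s` is just a throwaway helper):

```python
def mib_per_s(size_gib: float, seconds: float) -> float:
    # Convert GiB written over a wall-clock duration into MiB/s.
    return size_gib * 1024 / seconds

print(f"{mib_per_s(2.53, 7.76):.0f} MiB/s")   # distributed shard, per process: ~334
print(f"{mib_per_s(8.01, 13.60):.0f} MiB/s")  # first export file: ~603
```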

πŸ“œ Environment Information

DGX-H100, saving locally (NVMe).

jlamypoirier commented 3 weeks ago

Distributed checkpoint files are just a single big tensor, so the safetensors and torch formats are overkill. We can get away with a pure binary format like numpy's, which has near-optimal write speed and can be converted to/from torch without a copy. I got 987 MiB/s with this, including the CPU-GPU transfer, which is still way faster than safetensors. We could do better by buffering to overlap the transfer and the write, but it doesn't matter much because we'll want async checkpoints anyway.
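For illustration, a minimal sketch of the kind of raw round-trip this implies; the actual on-disk layout, header, and metadata handling are not specified here, so these helpers are hypothetical:

```python
# Hypothetical sketch: dump a shard as raw bytes and map it back without a copy.
import numpy as np
import torch


def save_shard(tensor: torch.Tensor, path: str) -> None:
    # Reinterpret as raw bytes first, so dtypes numpy doesn't know (e.g. bfloat16) also work.
    tensor.contiguous().cpu().view(torch.uint8).numpy().tofile(path)


def load_shard(path: str, dtype: torch.dtype) -> torch.Tensor:
    # np.fromfile reads the bytes; torch.from_numpy wraps the buffer without copying.
    return torch.from_numpy(np.fromfile(path, dtype=np.uint8)).view(dtype)
```

Swapping np.fromfile for np.memmap would additionally make the load lazy.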

So I propose switching distributed checkpoints to this raw binary format in v0.2.

This means dropping backward compatibility, but we were already planning to do that. For the future, I think we should always prevent loading distributed checkpoints saved with a different version, and enforce the state_dict format for long-term storage.

tscholak commented 3 weeks ago

I'd rather address this in 0.3 and put #25 in 0.2, mostly because #26 affects the internal checkpoint format, whereas #25 is user-facing and improves usability significantly. Let's discuss next week.

tscholak commented 2 weeks ago

re: speed. loading can be sped up, too: https://x.com/_philschmid/status/1853339965612060797
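One safetensors-side option for faster loading (this may or may not be what the linked post describes) is lazy, memory-mapped access via safe_open; the filename below is a placeholder:

```python
# Lazy, memory-mapped read: tensors are only materialized when accessed.
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    state = {key: f.get_tensor(key) for key in f.keys()}
```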