SeanNaren / min-LLM

Minimal code to train a Large Language Model (LLM).
MIT License

Improving Throughput with DeepSpeed #13

Closed: SeanNaren closed this issue 2 years ago

SeanNaren commented 2 years ago

The goal of this issue is to reach optimal throughput with DeepSpeed ZeRO Stage 3.

Results are collected with a single PyTorch/xFormers + DeepSpeed script on a machine with 8 A100 GPUs. I've been running the command below on this branch.

The TFLOPs calculation is based on the equation taken from the BigScience project, which can be seen here.
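For reference, here is a minimal sketch of that estimate, assuming the equation is the one from BigScience's throughput notes: each token costs about 2 FLOPs per parameter per forward pass, the backward pass costs about twice the forward, and activation checkpointing adds one more forward, giving a factor of 2 × 4 = 8 (or 2 × 3 = 6 without checkpointing). The function name and arguments are hypothetical, not taken from the script.

```python
def estimate_tflops_per_gpu(
    model_size_in_b: float,
    seq_len: int,
    global_batch_size: int,
    iter_time_s: float,
    num_gpus: int,
    activation_checkpointing: bool = True,
) -> float:
    """Hypothetical helper: BigScience-style estimate of achieved TFLOPs per GPU."""
    # forward (1) + backward (2) + recomputed forward with checkpointing (1)
    passes = 4 if activation_checkpointing else 3
    # model_size_in_b is in billions of parameters; 1e9 / 1e12 leaves a 1e3 divisor.
    return (model_size_in_b * passes * 2 * seq_len * global_batch_size) / (
        iter_time_s * num_gpus * 1e3
    )
```

For example, `estimate_tflops_per_gpu(1.0, 1000, 1, 1.0, 1)` returns `8.0`.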

I've been running a ~1.5B parameter model with block-sparse attention (sized following the Megatron-LM paper). It would be good to maximise throughput for this model, and to know how to scale the hyper-parameters so that smaller variants also reach good TFLOPs.
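When scaling to smaller variants, the common 12 · n_layer · n_embd² approximation gives a quick size estimate per shape. A minimal sketch; the helper name and the example shapes (48 layers × 1600 wide, GPT-2 XL-like, and 24 × 1024) are illustrative assumptions, not necessarily the configurations used on this branch.

```python
def approx_params(n_layer: int, n_embd: int, vocab_size: int = 0) -> int:
    """Hypothetical helper: rough parameter count of a GPT-style decoder."""
    # Per block: QKV + output projection = 4*d^2, MLP with 4x expansion = 8*d^2.
    # Biases and LayerNorm weights are ignored as negligible.
    blocks = 12 * n_layer * n_embd * n_embd
    # Token embedding matrix (assumed tied with the output head).
    embeddings = vocab_size * n_embd
    return blocks + embeddings


print(f"{approx_params(48, 1600):,}")  # ~1.47B, in the ~1.5B range
print(f"{approx_params(24, 1024):,}")  # a smaller hypothetical variant, ~0.3B
```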

    deepspeed --num_gpus 8 pytorch.py --attention blocksparse --batch_size_per_gpu 36 --precision bf16

Result: 124.49 TFLOPs (estimated), average iteration time 8.32 s.

This gets closer to the rough practical maximum of ~150 TFLOPs per A100 (roughly half of the ~300 TFLOPs bf16 peak). A few things are still left on the table:
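To put the measured figure in context, a quick sketch of how it compares with the rough 150 TFLOPs target and with NVIDIA's published dense BF16 Tensor Core peak for the A100 (312 TFLOPs). The variable names are illustrative.

```python
measured = 124.49        # TFLOPs per GPU from the run above
rough_ceiling = 300 / 2  # the ~150 TFLOPs practical target quoted above
a100_bf16_peak = 312.0   # NVIDIA spec: A100 dense BF16 Tensor Core peak

vs_ceiling = measured / rough_ceiling
vs_peak = measured / a100_bf16_peak

print(f"{vs_ceiling:.1%} of the rough ceiling")  # 83.0%
print(f"{vs_peak:.1%} of the BF16 spec peak")     # 39.9%
```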

The DeepSpeed config is below. The bucket sizes and the parameter persistence threshold scale with the hidden size `n_embd`, following the "auto" values in the Hugging Face Transformers DeepSpeed integration. This seems to prevent the cache allocator warnings I was seeing, which improved throughput!

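    # stage, n_embd, batch_size_per_gpu and precision come from the training script.
    config = {
        "zero_allow_untested_optimizer": True,
        "zero_optimization": {
            "stage": stage,
            "contiguous_gradients": True,
            "overlap_comm": True,
            "allgather_partitions": True,
            "reduce_scatter": True,
            # Sizes scale with the hidden size, mirroring Transformers' "auto" values.
            "reduce_bucket_size": n_embd * n_embd,
            # int(): DeepSpeed's config schema expects an integer here, not a float.
            "stage3_prefetch_bucket_size": int(0.9 * n_embd * n_embd),
            "stage3_param_persistence_threshold": 10 * n_embd,
        },
        "gradient_clipping": 1,
        "train_micro_batch_size_per_gpu": batch_size_per_gpu,
        "bf16": {"enabled": precision == "bf16"},
        # precision is parsed from the CLI as a string (e.g. "bf16"), so compare as str.
        "fp16": {"enabled": str(precision) == "16"},
    }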
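To see what those hidden-size-based values come out to, here is a minimal sketch; the helper name and the example width of 1600 are hypothetical. As far as I can tell from the DeepSpeed ZeRO config docs, these settings are element counts rather than bytes.

```python
def zero3_auto_sizes(n_embd: int) -> dict:
    """Hypothetical helper mirroring the n_embd-based ZeRO values in the config above."""
    return {
        # Elements reduced per bucket during gradient reduction.
        "reduce_bucket_size": n_embd * n_embd,
        # Size of the parameter prefetch buffer; cast to int as in the config.
        "stage3_prefetch_bucket_size": int(0.9 * n_embd * n_embd),
        # Parameters smaller than this are not partitioned across GPUs.
        "stage3_param_persistence_threshold": 10 * n_embd,
    }


sizes = zero3_auto_sizes(1600)  # hypothetical hidden size
for name, elems in sizes.items():
    print(f"{name}: {elems:,}")
```

Growing `n_embd` grows the buckets quadratically but the persistence threshold only linearly, so wider models keep proportionally fewer small parameters resident on every GPU.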

cc @tjruwase @jeffra who might be able to assist!

SeanNaren commented 2 years ago

There is a large throughput gap between plain PyTorch + DeepSpeed and PyTorch Lightning:

PyTorch Lightning: 119.65 TFLOPs
PyTorch + DeepSpeed: 129.24 TFLOPs

That is roughly an 8% improvement, which is substantial. Unfortunately, I don't have time to investigate the difference in speed, so I will probably move forward with the PyTorch/DeepSpeed combination.
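The 8% figure can be checked quickly. Assuming both runs used the same model, batch size, and sequence length, TFLOPs are inversely proportional to iteration time, so the same gap can also be read as a reduction in step time. A small sketch:

```python
lightning_tflops = 119.65
deepspeed_tflops = 129.24

# Relative throughput gain of plain PyTorch + DeepSpeed over Lightning.
speedup = deepspeed_tflops / lightning_tflops - 1
# Same work per iteration, so TFLOPs ~ 1 / iteration time:
# this is the fractional reduction in step time.
time_saved = 1 - lightning_tflops / deepspeed_tflops

print(f"throughput gain: {speedup:.1%}, step-time reduction: {time_saved:.1%}")
```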

SeanNaren commented 2 years ago

I've merged a bunch of changes into the main branch, including a rename of the project. I'm going to re-open a few issues to track things!