microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0
35.39k stars 4.11k forks

[math] what network throughput is required to handle ZeRO-3 traffic? #2928

Open stas00 opened 1 year ago

stas00 commented 1 year ago

Given a model size and number of GPUs, how can we calculate what throughput the interconnect network needs in order to handle ZeRO-3 traffic? Is 100Gbps enough, or does one need 1_000Gbps?

Users need to know what requirements to look for in the setup they buy or rent, so that they aren't network-bound and overpaying for idling GPUs.

Of course, this would also require knowledge of the GPUs, since in order not to be network-bound we need to ensure that comms <= compute + memory movement. We are of course discussing the compute/comms overlap here.

But even w/o having the knowledge of compute timing, it should be straightforward to calculate that to train a model of size X on Y gpus, this much traffic will be incurred for zero shards and that much for various reductions.

For ZeRO inference the need would be just zero shards traffic.

Thank you.

Anecdotally, when we were choosing which framework to use to train BLOOM-176, we had none of these numbers and had to benchmark the actual cluster and measure the overall throughput. For many users a cluster can be very difficult to procure before they commit to buying/renting hardware. It would have helped a lot to know that training a 176B-parameter model with ZeRO-3 on 384 GPUs requires a network of that many Gbps.

GuanhuaWang commented 1 year ago

Hi @stas00 , Thanks for raising this interesting question. 

The tl;dr answer is: to get reasonable GPU throughput when training at scale (64+ GPUs), 100 Gbps is not enough, 200-400 Gbps is OK, and 800-1000 Gbps is ideal.

Hardware cost-effectiveness: given that InfiniBand (IB) is usually more expensive than Ethernet, a 200 to 400/800 Gbps Ethernet link seems to be the more cost-effective solution. RDMA over Converged Ethernet (RoCE) can achieve throughput similar to IB (but with slightly longer latency for small messages).

Below is the math.

Notation: model size is M, number of GPUs per node is G, number of nodes is N, for a total of G*N GPUs in use.

Assumption: intra-node GPU communication overhead is ignored (since NVLink/NVSwitch are high-throughput links).

In ZeRO-3, we have an all-gather on weights (M) in forward, then an all-gather on weights (M) in backward, and finally a reduce-scatter on gradients (M) in backward. In total, 3 global collective calls.

For each of the above 3 collectives, each GPU needs to send M/(G*N) data outside the node as cross-node traffic, so each node needs to send out M/N.

Given that we usually use fp16 (2 bytes) to represent both weights and gradients, for each collective each node sends out 2M/N bytes. The 3 collectives in total require each node to send out 6M/N bytes, which is equal to 8 * 6M/N = 48M/N bits.
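
As a rough illustration, the arithmetic above can be written as a small Python sketch. The function names and the idea of dividing by a link speed to get a transfer time are assumptions added here for illustration. The sketch ignores compute/comms overlap and protocol overhead, and it follows the M/N accounting exactly as stated in this comment (that accounting is questioned later in the thread).

```python
def internode_bits_per_node(num_params, num_nodes, bytes_per_param=2, num_collectives=3):
    """Cross-node traffic sent by one node per iteration, in bits.

    Follows the accounting in this comment: each collective moves
    bytes_per_param * M / N bytes out of every node.
    """
    bytes_per_node = num_collectives * bytes_per_param * num_params / num_nodes
    return 8 * bytes_per_node


def transfer_seconds(bits, link_gbps):
    """Time to push `bits` through a link of `link_gbps` gigabits per second."""
    return bits / (link_gbps * 1e9)


# Example: 176B parameters on 24 nodes, fp16 weights and gradients.
bits = internode_bits_per_node(176e9, 24)
print(f"{bits / 1e9:.0f} Gbit per node per iteration")
for gbps in (100, 400, 800):
    print(f"{gbps} Gbps -> {transfer_seconds(bits, gbps):.2f} s of pure transfer time")
```

Under these assumptions the traffic is 48 * 176e9 / 24 = 352 Gbit per node per iteration, so the pure transfer time scales inversely with the link speed.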

The numbers we collected on 384 V100 GPUs (24 DGX-2 nodes) with a 176B model are:

We also notice that when training at scale, the communication overhead is more pronounced with a small micro-batch size per GPU. And we may not be able to increase the micro-batch size, since the global batch size is often fixed to achieve a good model convergence rate. We are trying to solve this issue with our upcoming release called ZeRO++, which could achieve better e2e system throughput when training at scale with a small micro-batch size using ZeRO-3. Stay tuned!

stas00 commented 1 year ago

whoah! This is a priceless sharing, @GuanhuaWang - XieXie!

Can we do a variation for bf16, which is rapidly taking over from fp16 for LLMs as we speak? Please note that DeepSpeed is changing to default to fp32 reduction traffic for bf16 (one of the upcoming PRs by @tjruwase); the rest should be the same as fp16.

@jeffra/@tjruwase could we put this in the documentation? Perhaps next to https://deepspeed.readthedocs.io/en/latest/memory.html, but call it a network requirements doc?

thomasw21 commented 1 year ago

That's awesome @GuanhuaWang

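What do you mean by micro-batch when using ZeRO? I was under the impression that micro-batches are only relevant when you use pipelining, but if you're not you might as well send the largest batches.

If we keep your reasoning for the compute/communication tradeoff:

  • Communication for each forward backward per node is 6 M / N bytes
  • Computation for each forward backward per node is 3 * 2 * M * T / N flops, T being the global number of tokens, roughly.

This should mean that you need to set T to something higher than the ratio of GPU performance (TFLOPS) to bandwidth.

Assuming the following setup:

  • V100: 112 TFLOPs
  • Bandwidth: 800 Gbps = 100 GBps
  • Infinite memory (otherwise there are new considerations to take into account, i.e. activation memory)

You would need something like batch_size * seq_len > 1.12 * 1e3 (assuming that every node is interconnected with that bandwidth)

This should get worse with A100:

  • A100: 312 TFLOPs
  • Bandwidth: 800 Gbps = 100 GBps
  • Infinite memory (otherwise there are new considerations to take into account, i.e. activation memory)

You would need something like batch_size * seq_len > 3.12 * 1e3 (assuming that every node is interconnected with that bandwidth)

If you have the worst bandwidth:

  • A100: 312 TFLOPs
  • Bandwidth: 100 Gbps = 12.5 GBps
  • Infinite memory (otherwise there are new considerations to take into account, i.e. activation memory)

You would need something like batch_size * seq_len > 24.96 * 1e3 (assuming that every node is interconnected with that bandwidth)

Everything is pure speculation, I haven't run anything yet to double check my math.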

tjruwase commented 1 year ago
  • Computation for each forward backward per node is 3 * 2 * M * T / N flops, T being tokens, roughly.

@thomasw21, why is computation divided by N?

thomasw21 commented 1 year ago

Per node

Edit: ah actually you're right we have to divide by dp size instead

Edit: ah no it's multiply by dp size and then divide by nodes

tjruwase commented 1 year ago

Okay, so T is the global number of tokens, not per node or per model replica?

thomasw21 commented 1 year ago

Okay actually scratch everything I said above. Yes T is global number of tokens.

GuanhuaWang commented 1 year ago

That's awesome @GuanhuaWang

What do you mean by micro-batch when using ZeRO? I was under the impression that micro-batches are only relevant when you use pipelining, but if you're not you might as well send the largest batches.

If we keep your reasoning for the compute/communication tradeoff:

  • Communication for each forward backward per node is 6 M / N bytes
  • Computation for each forward backward per node is 3 * 2 * M * T / N flops, T being the global number of tokens, roughly.

This should mean that you need to set T to something higher than the ratio of GPU performance (TFLOPS) to bandwidth.

Assuming the following setup:

  • V100: 112 TFLOPs
  • Bandwidth: 800 Gbps = 100 GBps
  • Infinite memory (otherwise there are new considerations to take into account, i.e. activation memory)

You would need something like batch_size * seq_len > 1.12 * 1e3 (assuming that every node is interconnected with that bandwidth)

This should get worse with A100:

  • A100: 312 TFLOPs
  • Bandwidth: 800 Gbps = 100 GBps
  • Infinite memory (otherwise there are new considerations to take into account, i.e. activation memory)

You would need something like batch_size * seq_len > 3.12 * 1e3 (assuming that every node is interconnected with that bandwidth)

If you have the worst bandwidth:

  • A100: 312 TFLOPs
  • Bandwidth: 100 Gbps = 12.5 GBps
  • Infinite memory (otherwise there are new considerations to take into account, i.e. activation memory)

You would need something like batch_size * seq_len > 24.96 * 1e3 (assuming that every node is interconnected with that bandwidth)

Everything is pure speculation, I haven't run anything yet to double check my math.

Hi @thomasw21 , thanks for the detailed reply and math.

Sorry for the naming confusion with pipeline parallelism. By micro-batch here I mean the per-GPU batch size: say we have a global batch size of X and the number of GPUs is Y, then the per-GPU micro-batch is X/Y, given that ZeRO indeed mimics data parallelism. X is often fixed for a model, but Y can change.

The theoretical GPU TFLOPs ceiling is hard to achieve given all the scheduling/networking overhead and limited on-device memory. Practically, we believe 30+ TFLOPs per V100 is a usable case. If we assume infinite memory (which could never be true), then I agree with the math here.
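
For reference, the break-even condition discussed here can be sketched in a few lines of Python. The function names are made up for illustration. The sketch simply sets compute time (6*M*T/N divided by FLOPS) against communication time (6*M/N divided by bandwidth), so the model size cancels and the threshold becomes T > FLOPS / bandwidth. It keeps the same simplifications as the thread (infinite memory, per-GPU FLOPS compared against the internode bandwidth).

```python
def gbps_to_bytes_per_second(gbps):
    """Convert a link speed in gigabits per second to bytes per second."""
    return gbps * 1e9 / 8


def min_tokens_to_hide_comms(gpu_tflops, link_gbps):
    """Smallest T (batch_size * seq_len) for which compute time >= comms time.

    compute time = 6*M*T/N / flops, comms time = 6*M/N / bandwidth,
    so the threshold is T = flops / bandwidth.
    """
    return gpu_tflops * 1e12 / gbps_to_bytes_per_second(link_gbps)


setups = [
    ("V100 peak, 800 Gbps", 112, 800),
    ("A100 peak, 800 Gbps", 312, 800),
    ("A100 peak, 100 Gbps", 312, 100),
    ("V100 at a practical 30 TFLOPs, 800 Gbps", 30, 800),
]
for name, tflops, gbps in setups:
    print(f"{name}: batch_size * seq_len > {min_tokens_to_hide_comms(tflops, gbps):,.0f}")
```

The first three rows reproduce the 1.12e3, 3.12e3 and 24.96e3 thresholds above. The last row shows how much the threshold drops when the achieved rather than the peak TFLOPs figure is used.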

wptoux commented 1 year ago

Why is the amount of communication between nodes M/N? After all, each node needs to get the parameters held on all other nodes, which looks like M * (N - 1) / N. And wouldn't it be a bit counter-intuitive that, with a very large cluster of, say, 1000 nodes, the communication overhead per node would be small?
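
To make the difference between the two accountings concrete, here is a minimal Python sketch. The function names are hypothetical. It only compares the M/N figure used earlier in the thread with the M * (N - 1) / N figure proposed in this comment, both expressed as a fraction of the model size M.

```python
def fraction_m_over_n(num_nodes):
    """Per-node traffic as a fraction of M under the M/N accounting."""
    return 1 / num_nodes


def fraction_all_other_shards(num_nodes):
    """Per-node traffic as a fraction of M if a node must fetch every shard it does not hold."""
    return (num_nodes - 1) / num_nodes


for n in (2, 24, 1000):
    print(
        f"N={n}: M/N -> {fraction_m_over_n(n):.3f} * M, "
        f"M*(N-1)/N -> {fraction_all_other_shards(n):.3f} * M"
    )
```

With M/N the per-node traffic shrinks towards zero as nodes are added, whereas with M * (N - 1) / N it approaches the full model size M.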

stas00 commented 1 year ago

@GuanhuaWang, @jeffra - let's revive this thread and give it a higher priority if you're willing to support that - the main question I'm being asked very often these days is what internode bandwidth is required to choose Deepspeed over TP+PP+DP scalability frameworks.

So I send users to https://github.com/microsoft/DeepSpeed/issues/2928#issuecomment-1463041491 but that information is for V100 and thus very much outdated.

I'd expect higher requirements for A100 nodes, and, of course, we are also migrating to H100s everywhere.

Thank you very much!

Anecdotally we trained IDEFICS-80B with 340Gbps internode and we were able to get only 90TFLOPs on A100 nodes, as compared to 150TFLOPs we were getting with BLOOM-176 on Megatron-Deepspeed on only 140Gbps network.

stas00 commented 1 year ago

Also, as I started reproducing this math, I found there are many more things to take into account here with regards to the 3x multiplier, which in https://github.com/microsoft/DeepSpeed/issues/2928#issuecomment-1463041491 is 2b+2b+2b (fwd+bwd+grad).

Such as:

And so now we need to translate this to A100s being 3x faster and H100s being 9x faster compared to V100. And let's not even go into fp8 compute just yet.

So now we want comms <= compute so that the comms aren't the bottleneck. If 3x was the multiplier, then a very rough projection for great throughput will require 800*9 => 5.6Tbps, which none of the H100 node providers will supply - at best you will get 3.2Tbps peak spec. And now with recomputation and fp32 reduction the requirement almost doubles again, to 10Tbps.
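
One way to see where "almost another double" can come from is to count bytes moved per parameter per iteration. The Python sketch below is only an illustration and rests on two assumptions that are not spelled out in this thread: that recomputation (activation checkpointing) triggers one additional weight all-gather, and that fp32 reduction makes the gradient traffic 4 bytes per parameter instead of 2. The function name is made up.

```python
def bytes_per_param(weight_bytes=2, grad_bytes=2, recompute=False):
    """Bytes of collective traffic per parameter per iteration.

    Baseline is 2b+2b+2b (fwd all-gather + bwd all-gather + grad reduce-scatter).
    ASSUMPTION: recomputation adds one more weight all-gather.
    """
    forward_gather = weight_bytes
    backward_gather = weight_bytes
    recompute_gather = weight_bytes if recompute else 0
    return forward_gather + backward_gather + recompute_gather + grad_bytes


baseline = bytes_per_param()                              # 2b + 2b + 2b
fp32_reduce = bytes_per_param(grad_bytes=4)               # 2b + 2b + 4b
heaviest = bytes_per_param(grad_bytes=4, recompute=True)  # 2b + 2b + 2b + 4b

print(f"baseline: {baseline} bytes/param")
print(f"fp32 reduction: {fp32_reduce} bytes/param")
print(f"fp32 reduction + recomputation: {heaviest} bytes/param "
      f"({heaviest / baseline:.2f}x the baseline)")
```

Under these assumptions the traffic grows from 6 to 10 bytes per parameter, and the bandwidth requirement grows by the same factor.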

So either I'm missing something in my math, or something is wrong in the original benchmark.

And of course, if possible, making an up-to-date A100 and H100 recommendation benchmark would hugely help the community.

E.g. I'm practically trying to decide whether ZeRO is an efficient choice for the upcoming H100 training, but I don't yet have the GPUs to do any measurements. And you probably do, so please help us choose DeepSpeed ;)

stas00 commented 1 year ago

I was able to confirm with @samyam that dividing by the number of Nodes is incorrect in the math of https://github.com/microsoft/DeepSpeed/issues/2928#issuecomment-1463041491

You can find the correct math here: https://github.com/stas00/ml-engineering/tree/master/model-parallelism#inter-node-speed-requirements-to-use-zero

deanpeterson commented 5 months ago

@GuanhuaWang I've been using OpenShift AI and running some DeepSpeed tests with Ray.io on 6 nodes connected by a 10Gb network. I've noticed that a training run that is supposed to take at most 36GB of VRAM in total is using over 20GB of VRAM on each of the 6 nodes, for a total of over 120GB. Can the excess VRAM use be due to the slow network speed?