microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

How to reproduce BERT perf results in deepspeed blog #272

Open LiweiPeng opened 4 years ago

LiweiPeng commented 4 years ago

Hi Deepspeed team,

The BERT performance results in the blog post "Microsoft DeepSpeed achieves the fastest BERT training time" are very impressive. However, I couldn't reproduce the results in Figure 1 of that post.

For example, using the latest nvbert code with BERT Large and max-seq-len=128, the maximum batch size I got is 136, at 194.57 examples/s. In Figure 1, however, nvbert's maximum batch size is about 82, at 215 examples/s. That nvbert throughput is about 10% higher than what I got.

Can you share the detailed parameters (and/or code) behind Figure 1, so that I can reproduce the nvbert and Hugging Face BERT results? Thanks.

Thanks. Liwei

RezaYazdaniAminabadi commented 4 years ago

Hi Liwei,

Thanks for trying out the DeepSpeed transformer kernel. I assume you are using a 32 GB NVIDIA V100 GPU for this experiment, right? What total batch size are you using, i.e., is gradient accumulation 1 or higher? Also, are you training with a post-LN or a pre-LN BERT layer? We use the pre-LN BERT structure and set gradient accumulation to 10 for all our experiments. Finally, which dataset are you using for training?

Thanks. Reza

LiweiPeng commented 4 years ago

@RezaYazdaniAminabadi, thanks for your quick response.

To answer your questions:

One difference from your setup is that I am using post-LN, not pre-LN. I am not sure whether this explains the differences in max batch size and performance. I'll try pre-LN later.

Thanks. Liwei

RezaYazdaniAminabadi commented 4 years ago

Thanks for the reply. If you want to try the pre-LN model, you can use the example in DeepSpeed: https://github.com/microsoft/DeepSpeedExamples/blob/8610e5e3fcce5fb247e3b85ea2bed0f2296b5443/bing_bert/nvidia/modelingpreln.py. Regarding the performance difference, pre-LN and post-LN may perform somewhat differently. Another difference is that we use DeepSpeed to run the NVIDIA modeling code. You can try the DeepSpeed example: https://github.com/microsoft/DeepSpeedExamples/tree/8610e5e3fcce5fb247e3b85ea2bed0f2296b5443/bing_bert. To run pretraining, use the ds_train_bert_bsz64k_seq128.sh script and change the batch size in deepspeed_bsz64k_lamb_config_seq128.json.
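For readers comparing the two variants, here is a minimal, framework-free sketch of where LayerNorm sits in a post-LN versus a pre-LN residual block. It is a simplification (no learned gain/bias, no dropout, and a toy `sublayer` standing in for attention or the feed-forward network), not the code in modelingpreln.py or the DeepSpeed kernel.

```python
import math
from typing import Callable, List

Vector = List[float]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Normalize to zero mean and unit variance (learned gain/bias omitted)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def add(a: Vector, b: Vector) -> Vector:
    return [p + q for p, q in zip(a, b)]


def post_ln_block(x: Vector, sublayer: Callable[[Vector], Vector]) -> Vector:
    # Post-LN (original BERT): residual add first, then LayerNorm on the sum.
    return layer_norm(add(x, sublayer(x)))


def pre_ln_block(x: Vector, sublayer: Callable[[Vector], Vector]) -> Vector:
    # Pre-LN: LayerNorm only on the sublayer input; the residual path is untouched.
    return add(x, sublayer(layer_norm(x)))


x = [1.0, 2.0, 3.0, 4.0]
half = lambda v: [0.5 * e for e in v]  # toy sublayer
print(post_ln_block(x, half))
print(pre_ln_block(x, half))
```

The post-LN output is always re-normalized, while the pre-LN output keeps the un-normalized residual stream, which is the structural difference the thread refers to.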

LiweiPeng commented 4 years ago

Thanks, Reza. I'll try the DeepSpeed example.

JF-D commented 4 years ago

@RezaYazdaniAminabadi I tried the DeepSpeed bing_bert example, but I only got 200 samples/sec with batch size 64 and gradient accumulation of 10 at max sequence length 128. I am using a 32 GB V100 GPU and a dummy dataloader. I am quite confused by this result.

Besides, for distributed BERT Large training, I found that the communication time is quite long. I wonder how you got such a high speedup ratio.

Update: by enabling attention dropout checkpointing and normalize invertible in the transformer kernel, I got a 5% gain, but it's still far from the DeepSpeed result.

JF-D commented 4 years ago

@LiweiPeng Hi, did you try the DeepSpeed example and get the same performance as in the DeepSpeed blog?

LiweiPeng commented 4 years ago

@JF-D I haven't had a chance to try the DeepSpeed example yet. I will update once I do.

RezaYazdaniAminabadi commented 4 years ago

Hi @JF-D, thanks for trying out the DeepSpeed transformer kernel. I just ran the DeepSpeedExamples BERT pretraining with the --deepspeed_transformer_kernel flag set in the ds_train_bert_bsz64k_seq128.sh script. I also changed train_batch_size from 64K to 640 (considering gradient accumulation of 10) in the DeepSpeed JSON config file, and kept the rest of the configuration at its defaults. I get much higher throughput, around 257 samples/second, without setting any other optimization flags. You can improve performance further by setting the flags you mentioned and increasing the batch size.

Can you please share more information about your run, such as the configuration you are using and the script you launch training with?

For the distributed training, can you please set the wall_clock_breakdown flag in the DeepSpeed config to true? That way we can see more detail about your training system, such as the forward, backward, and all-reduce times. Once I get this information, I can help you get the expected performance from DeepSpeed! :)
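As a rough illustration of the batch-size arithmetic above: DeepSpeed requires train_batch_size to equal train_micro_batch_size_per_gpu * gradient_accumulation_steps * number of GPUs, so 640 on one GPU with gradient accumulation 10 means a micro-batch of 64. The helper `batch_config` below is hypothetical (it is not part of DeepSpeedExamples) and only emits the batch-size keys and wall_clock_breakdown; the rest of deepspeed_bsz64k_lamb_config_seq128.json (optimizer, FP16 settings, etc.) is left out.

```python
import json


def batch_config(micro_batch_per_gpu: int, grad_accum_steps: int,
                 world_size: int, wall_clock_breakdown: bool = True) -> dict:
    """Build only the batch-size-related part of a DeepSpeed JSON config."""
    return {
        # DeepSpeed checks train_batch_size == micro batch * accumulation * GPUs.
        "train_batch_size": micro_batch_per_gpu * grad_accum_steps * world_size,
        "train_micro_batch_size_per_gpu": micro_batch_per_gpu,
        "gradient_accumulation_steps": grad_accum_steps,
        # Turns on DeepSpeed's per-phase timing logs (forward/backward/step).
        "wall_clock_breakdown": wall_clock_breakdown,
    }


# Single V100, micro-batch 64, accumulation 10 -> train_batch_size 640.
cfg = batch_config(micro_batch_per_gpu=64, grad_accum_steps=10, world_size=1)
print(json.dumps(cfg, indent=2))
```

Merging these keys into the example's existing JSON file, rather than replacing the file, keeps the remaining settings at their defaults.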

Thanks. Reza

RezaYazdaniAminabadi commented 4 years ago

By the way, even if you set gradient accumulation to 1, you should get high samples/sec on a single NVIDIA V100 GPU (32 GB).

JF-D commented 4 years ago

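@RezaYazdaniAminabadi Thanks for your quick response. I followed your configuration (the --deepspeed_transformer_kernel flag, train_batch_size 640 with gradient accumulation of 10, and default settings otherwise), except that I made some changes to the dataloader. However, I used a dummy dataloader to rule out its effects.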

After setting the wall_clock_breakdown flag in the DeepSpeed config to true, I get this log.

Other details: I use PyTorch 1.5, CUDA 10.1, cuDNN 7.6.5, and NCCL 2.6.4. Here is my GPU info.

RezaYazdaniAminabadi commented 4 years ago

Hi @JF-D,

Thanks for providing more information about your experiment. Clearly, there are some hardware differences here: I am using a DGX-2, and I am running the evaluation with PyTorch 1.2. You can find my GPU info as follows. In my experiment, the average power draw is around 340 W and GPU utilization is >98%.
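To compare power and utilization figures like these across machines, one option is to sample them with nvidia-smi's query mode, for example `nvidia-smi --query-gpu=utilization.gpu,power.draw --format=csv,noheader,nounits -l 1`, and summarize the log afterwards. The parser below is a hypothetical helper. It assumes every non-blank line has the form `<util>, <power>`, pools all GPUs together, and uses an arbitrary 50% threshold to count low-utilization samples.

```python
import statistics
from typing import Dict, Iterable


def summarize_gpu_log(lines: Iterable[str], low_util_pct: float = 50.0) -> Dict[str, float]:
    """Summarize 'utilization, power' samples from nvidia-smi CSV output."""
    utils, powers = [], []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        util_str, power_str = (field.strip() for field in line.split(","))
        utils.append(float(util_str))
        powers.append(float(power_str))
    if not utils:
        raise ValueError("no samples found")
    return {
        "mean_util_pct": statistics.mean(utils),
        "mean_power_w": statistics.mean(powers),
        # Share of samples below the threshold: a rough indicator of GPU stalls.
        "low_util_fraction": sum(u < low_util_pct for u in utils) / len(utils),
    }
```

A high `low_util_fraction` would point toward the kind of stalls discussed later in the thread (for example, data loading), rather than slower kernels.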

The FWD/BWD time you've got is almost 11.5% higher than what I am seeing:

Also, I see that you made some changes to the data loader. To speed up data loading, we use an asynchronous worker on the CPU that prefetches data in advance. So my suggestion is to use a fresh copy of our DeepSpeed example, pull the latest DeepSpeed changes, and try the experiment again. Let's see if your numbers improve! :) Thanks.
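The prefetching idea can be illustrated with a minimal sketch: a background thread fills a bounded queue so that the next batches are ready before the training loop asks for them. `PrefetchLoader` is a hypothetical name; this is not the AsyncWorker from DeepSpeedExamples, and a thread-based loader only helps when loading releases the GIL (for example, file or network I/O).

```python
import queue
import threading
from typing import Any, Iterable, Iterator

_END = object()  # marks the end of the source iterable


class PrefetchLoader:
    """Iterate over `source` while a background thread loads ahead.

    At most `prefetch` items are buffered; the producer blocks when the
    buffer is full, so memory use stays bounded.
    """

    def __init__(self, source: Iterable[Any], prefetch: int = 3) -> None:
        self._queue = queue.Queue(maxsize=prefetch)
        self._thread = threading.Thread(target=self._produce, args=(source,), daemon=True)
        self._thread.start()

    def _produce(self, source: Iterable[Any]) -> None:
        try:
            for item in source:
                self._queue.put(item)  # blocks while the buffer is full
        finally:
            # Always signal the end so the consumer cannot block forever,
            # even if the source raises (the thread then reports the error).
            self._queue.put(_END)

    def __iter__(self) -> Iterator[Any]:
        while True:
            item = self._queue.get()
            if item is _END:
                return
            yield item
```

Usage would look like `for batch in PrefetchLoader(batches, prefetch=3): ...`. A larger `prefetch` hides more loading latency at the cost of holding more batches in host memory.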

By the way, you may want to turn on the --stochastic-mode in the training script to boost the performance of the transformer kernel.

Best regards, Reza

JF-D commented 4 years ago

@RezaYazdaniAminabadi Thanks a lot. I have a few more questions and I look forward to receiving your answers.

  1. It seems the speed difference between our runs is caused by different hardware (V100 SXM3 vs. SXM2). I don't have a DGX-2 and don't know much about how the two differ for training. Do you think this difference is normal, i.e., that a DGX-2 can be 10% faster?
  2. For your training, did you use NVMe? I didn't, but it seems NVMe might give some performance gain. I am not sure.
  3. I found that my GPU doesn't seem to be fully utilized (GPU utilization is sometimes < 50%), and the speed in the log occasionally fluctuates, with iterations taking longer. I think this is abnormal; could there be a problem with the machine?
  4. I am quite interested in your all-reduce time for large-scale distributed training. BERT Large has ~330M parameters, so with 2-byte (FP16) gradients the communication volume per iteration is ~660 MB. DeepSpeed can train BERT Large in 44 minutes on 1024 GPUs, which is an amazing result. Did you apply any other optimizations to all-reduce? In my opinion, all-reduce time is a big bottleneck in large-scale scenarios.

Looking forward to your reply!

Thanks, Jiangfei

TonyTangYu commented 4 years ago

Hi, @LiweiPeng

I am trying to run the examples on the Wikipedia and BookCorpus datasets. However, the official DeepSpeed website says that downloading and pre-processing instructions for these datasets are coming soon. I noticed that you ran some experiments on these two datasets. Could you please share the pre-processing steps for them, so that I can run the DeepSpeed BERT pre-training examples?

Thanks! Tony

tjruwase commented 4 years ago

@TonyTangYu, can you try the nvidia datasets? Please see relevant information here and here.

TonyTangYu commented 4 years ago

@tjruwase , thanks for your reply! I will give it a try! Thanks again!

gongjingcs commented 3 years ago

We have the same doubt: we are also trying to reproduce the BERT pre-training results. According to the description, the 44-minute result did not use any communication (all-reduce) optimizations, so multi-node scalability should become a bottleneck. It is not clear how DeepSpeed solved this problem.

gongjingcs commented 3 years ago

@RezaYazdaniAminabadi Looking forward to your reply!

RezaYazdaniAminabadi commented 3 years ago

Hi @gongjingcs and @JF-D,

Sorry for the late reply. Here are my answers:

  1. Regarding the GPU architecture (SXM3 vs. SXM2), I don't see much difference. However, the two GPUs seem to have different peak power consumption, which means one should be able to run at a higher clock frequency.

  2-3. As @JF-D mentioned, the memory subsystem can be very important, since we are constantly bringing minibatch data to the GPU. We are using NVMe. If GPU utilization cannot stay high and fluctuates a lot, as you are seeing, I assume that is due to the latency of accessing data in memory, which stalls the GPU pipeline and causes a lot of idle time. In DeepSpeed, we prefetch the data using an AsyncWorker (https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/data_worker.py). You may need to specify a larger prefetching queue (we currently prefetch 3 requests) to hide the latency of accessing main memory. We initialize the async worker here: https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/bing_bert_dataset_provider.py#L84-L86

  4. Regarding the last question, we use the all-reduce from the torch distributed library, and since the GPUs are connected through NVLink (~300 GB/s of bandwidth), the all-reduce overhead is lower. However, as you mentioned, this latency hurts end-to-end performance as we increase the number of nodes, because optimization steps become more frequent: with a fixed global batch size, more nodes means lower gradient accumulation, so each all-reduce is amortized over less compute. On the other hand, thanks to the efficient all-reduce implementation, its communication latency is almost independent of the number of nodes and depends mainly on the message size.
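The point about message size can be made concrete with the textbook ring all-reduce cost model, sketched below. It is an assumption that this model approximates what torch.distributed/NCCL does in these runs (NCCL can also choose tree algorithms), and the FP16 gradient size is the estimate from earlier in the thread, not a measured value. In this model each rank sends and receives 2(N-1)/N times the message size, which approaches twice the message as N grows, so the bandwidth term barely changes with N; the per-step latency term, however, still grows linearly with N.

```python
def ring_allreduce_bytes_per_rank(message_bytes: float, world_size: int) -> float:
    """Bytes each rank sends (and receives) in a textbook ring all-reduce.

    The ring runs N-1 reduce-scatter steps and N-1 all-gather steps,
    each moving a chunk of message_bytes / N.
    """
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    return 2 * (world_size - 1) / world_size * message_bytes


def ring_allreduce_seconds(message_bytes: float, world_size: int,
                           bandwidth_bytes_per_s: float,
                           step_latency_s: float = 0.0) -> float:
    """Alpha-beta estimate: 2(N-1) latency-bound steps plus the bandwidth term."""
    steps = 2 * (world_size - 1)
    volume = ring_allreduce_bytes_per_rank(message_bytes, world_size)
    return steps * step_latency_s + volume / bandwidth_bytes_per_s


# ~330M parameters with 2-byte (FP16) gradients, as estimated in the thread.
message = 330e6 * 2
for n in (2, 16, 256, 1024):
    gb = ring_allreduce_bytes_per_rank(message, n) / 1e9
    print(f"N={n:5d}: {gb:.3f} GB per rank (bound {2 * message / 1e9:.2f} GB)")
```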

Thanks, Reza

hxbloom commented 3 years ago

Hi @JF-D ,

I tried both attention dropout checkpointing and normalize invertible. Here are my results: attention dropout checkpointing changes speed by -2% to -1%, and normalize invertible by +1% to +2%. I cannot get the 5% gain you mentioned.

AFAIK, attention dropout checkpointing recomputes the forward pass of attention dropout during the backward step. It's a time-memory trade-off.
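The trade-off can be illustrated with a toy, pure-Python sketch of one attention row, context = dot(dropout(probs), values). Without checkpointing, the dropped probabilities are stored for the backward pass. With checkpointing, only probs and the mask are kept (probs is assumed to be saved anyway for the softmax gradient), and the cheap elementwise dropout is recomputed when the gradient needs it. This is an illustration under those assumptions, not the DeepSpeed CUDA kernel; all names are hypothetical.

```python
import random
from typing import Dict, List, Tuple


def make_mask(n: int, p: float, rng: random.Random) -> List[bool]:
    """Keep each element with probability 1 - p."""
    return [rng.random() >= p for _ in range(n)]


def apply_dropout(x: List[float], mask: List[bool], p: float) -> List[float]:
    """Inverted dropout: zero dropped elements, scale kept ones by 1 / (1 - p)."""
    scale = 1.0 / (1.0 - p)
    return [v * scale if keep else 0.0 for v, keep in zip(x, mask)]


def forward(probs: List[float], values: List[float], p: float,
            rng: random.Random, checkpoint: bool) -> Tuple[float, Dict[str, list]]:
    """Toy attention row: context = dot(dropout(probs), values)."""
    mask = make_mask(len(probs), p, rng)
    dropped = apply_dropout(probs, mask, p)
    context = sum(d * v for d, v in zip(dropped, values))
    saved = {"probs": probs, "mask": mask}  # kept in both modes
    if not checkpoint:
        saved["dropped"] = dropped  # extra activation held until backward
    return context, saved


def grad_wrt_values(grad_context: float, saved: Dict[str, list], p: float) -> List[float]:
    """d(context)/d(values[i]) = dropped[i], scaled by the upstream gradient."""
    dropped = saved.get("dropped")
    if dropped is None:
        # Checkpointed: recompute dropout from the saved probs and mask.
        dropped = apply_dropout(saved["probs"], saved["mask"], p)
    return [grad_context * d for d in dropped]
```

Both modes produce identical gradients; the checkpointed one holds one fewer activation and does one extra elementwise pass in backward, which is why its speed cost should be small for a memory-bandwidth-bound operator.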

Could you please share the environment in which you got the 5% gain? Looking forward to your reply.

Thanks, Dong

JF-D commented 3 years ago

Hi @hxbloom, sorry for the late reply. You can find my hardware environment in the comments above. For PyTorch, I use version 1.5.

  1. I enabled attention dropout checkpointing and normalize invertible at the same time; it is possible that the kernel is faster with that combination. Re-computation is indeed a time-memory trade-off, but dropout is a memory-bandwidth-bound operator, so I think re-computing it may have little influence on speed.
  2. Otherwise, this may be a problem with my hardware environment. As I mentioned in https://github.com/microsoft/DeepSpeed/issues/272#issuecomment-650133971, my GPU utilization is sometimes low, so the speed I measured may not be very reliable.

Thanks, Jiangfei

hxbloom commented 3 years ago

Got it, thanks for your reply!

Dong