Leinao / GPUCluster

用于讨论在使用https://bitahub.ustc.edu.cn 过程中碰到的问题
MIT License
2 stars 0 forks source link

deepspeed显存测量实验 #6

Open Ruanwenna opened 1 year ago

Ruanwenna commented 1 year ago

Dockerfile:

FROM huggingface/transformers-pytorch-gpu:latest
RUN pip install deepspeed -i https://pypi.tuna.tsinghua.edu.cn/simple

代码来源: https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat

下载相关包:

git clone https://github.com/microsoft/DeepSpeedExamples.git
cd DeepSpeedExamples/applications/DeepSpeed-Chat/
pip install -r requirements.txt

启动命令:

#进入第一步监督微调脚本文件所在目录
cd DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/training_scripts
#选择single_node配置方式运行相应模型脚本(例如运行一个1.3b的模型)
bash single_node/run_1.3b.sh

调参数:

#进入第一步监督微调脚本文件所在目录
cd DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/training_scripts
#修改脚本文件中的micro_batch 和max_seq_len(例如运行一个1.3b的模型)
vim single_node/run_1.3b.sh
#调整allgather_bucket_size和reduce_bucket_size,先进入training文件夹
cd DeepSpeedExamples/applications/DeepSpeed-Chat/training
#在ds_utils.py中找到zero的配置,加入allgather_bucket_size和reduce_bucket_size的配置信息(例如"allgather_bucket_size": 2e8)
vim utils/ds_utils.py

实验结果:

model gpu_num zero_stage micro-batches max-seq-len time(10steps) MaxMemAllocated(G)
opt-1.3b 7(rtx3090) 2 4 512 16s 15.08
opt-1.3b 7(rtx3090) 3 4 512 22s 15.45
opt-1.3b 7(rtx3090) 2 2 256 14s 7.51
opt-1.3b 7(rtx3090) 2 3 256 21s 8.06
opt-1.3b 7(rtx3090) 2 2 128 14s 6.18
opt-1.3b 7(rtx3090) 2 3 128 21s 7.14
opt-1.3b 7(rtx3090) 3 8 512 --- OOM
opt-1.3b 8(1080ti) 2 2 128 14s 6.46
opt-1.3b 4(rtx3090) 3 4 512 14s 17.57
opt-6.7b 7(rtx3090) 3 2 256 1m53s 20.54
opt-6.7b 7(rtx3090) 3 2 128 1m53s 20.52
opt-6.7b 7(rtx3090) 2 2 128 --- OOM

实验结果分析: 在opt-1.3b模型的训练过程中降低micro-batches和max_seq_len可以有效降低显存消耗,而opt-6.7b模型中降低这两个参数没有明显的显存变化。猜测在大模型训练过程中deepspeed根据显存消耗的情况开启了一些显存优化程序,例如激活值重计算。

Ruanwenna commented 1 year ago

代码更改:

#进入第一步监督微调的main.py文件所在目录
cd DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning
#将main.py中的“assert not args.offload, "zero-offload is not currently supported but coming soon!"”注释掉
vim main.py
#进入脚本文件中将--offload配置加上
vim single_node/run_1.3b.sh

   --offload \
   --zero_stage $ZERO_STAGE \
   --deepspeed \

实验结果:

model gpu_num zero_stage micro-batches max-seq-len time(10steps) MaxMemAllocated(G) MemAllocated(G)
opt-1.3b 7 3+offload 2 256 2m510s 4.68 0.44
opt-1.3b 4(1080ti) 3+offload 2 128 4m42s 3.76 0.41
opt-1.3b 2(1080ti) 3+offload 2 128 --- OOM OOM
opt-6.7b 7 3+offload 2 256 5m7s 2.08 0.44
opt-6.7b 4 3+offload 2 128 4m55s 2.08 0.41
opt-6.7b 4 2+offload 2 128 2m30s 15.12 12.85
opt-6.7b 2 3+offload 2 256 2m42s 2.37 0.44
opt-6.7b 2 2+offload 2 256 1m43s 15.52 12.83
opt-6.7b 2(bitahub) 3+offload 2 128 --- OOM OOM
opt-6.7b 1 3+offload 2 256 3m35s 3.18 0.44
opt-6.7b 1 3+offload 8 512 3m50s 4.52 0.77
opt-6.7b 4(v100) 3+offload_opt 2 256 11m 5.31 3.46
opt-6.7b 2(v100) 3+offload_opt 2 256 --- OOM OOM
opt-13b 7 3+offload 2 256 7m56s 2.38 0.44
opt-13b 7 2+offload 2 256 --- OOM OOM
opt-13b 4 3+offload 2 256 6m 2.4 0.44
opt-13b 2 3+offload 2 256 6m 2.86 0.44
opt-13b 1 3+offload 2 256 7m59s 3.76 0.44

实验分析:

zero的内存估算程序:

from transformers import AutoModel
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live
from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live

## specify the model you want to train on your device
model = AutoModel.from_pretrained("facebook/opt-13b")

## estimate the memory cost (both CPU and GPU)
#可以通过更改节点数量和单节点的GPU数量来做预估
estimate_zero3_model_states_mem_needs_all_live(model, num_gpus_per_node=7, num_nodes=1)
estimate_zero2_model_states_mem_needs_all_live(model, num_gpus_per_node=7, num_nodes=1)

上述程序的估算具体实现代码链接: zero3的估算 https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage3.py zero2和zero1的估算 https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage_1_and_2.py

Ruanwenna commented 1 year ago

更改zero配置:

#进入到utils中更改zero的配置,
cd DeepSpeedExamples/applications/DeepSpeed-Chat/training
#进入ds_utils.py将 "offload_param"中的"device"改成none时表示将参数保留在GPU中
vim ds_utils.py

batch size 和seq_len对显存影响实验 实验结果:

model offload_param micro-batches max-seq-len Memalloc MaxMemAllocated(G) time(10steps)
opt-6.7b none 8 512 2.6 7.77 6m10s
opt-6.7b cpu 8 512 0.77 3.93 5m26s
opt-6.7b none 16 512 2.99 9.16 5m52s
opt-6.7b cpu 16 512 1.16 7.34 5m49s
opt-6.7b none 32 512 3.75 13.5 6m48s
opt-6.7b cpu 32 512 1.93 11.67 7m
opt-13b none 2 256 3.93 5.87 8m31s
opt-13b cpu 2 256 0.44 2.38 8m3s
opt-13b none 8 512 4.26 11.7 9m8s
opt-13b cpu 8 512 0.77 8.21 8m52s
opt-13b none 16 512 4.65 11.99 9m41s
opt-13b cpu 16 512 1.16 8.5 9m14s
opt-13b none 32 512 5.42 18.79 12m10s
opt-13b none 32 512 1.93 15.3 11m11s

实验分析:

  1. 6.7b: micro_batch_size从16->32时memalloc增长0.77,maxmemalloc增长4.33 8->16 memalloc增长0.39 maxmemalloc增长3.41
  2. 13b:16->32时 memalloc增长0.77 maxmemalloc增长6.8 8->16 memallo增长0.39 maxmemalloc增长0.29(loss scale是否影响内存占用) 2 256->8 512 memalloc增长0.33 maxmemalloc增长5.83
  3. memalloc只与数据相关即只与micro_batch_size 和 max_seq_len相关
  4. maxmemalloc与数据和参数量都相关:maxmemalloc应该包括梯度缓冲区,激活值内存,参数内存,数据内存
    • 将参数卸载关闭没有带来训练速度的提升,相较同样配置的卸载反而变慢,与预期不符
    • 当batch_size成倍增长时每10个step的运行时间有少量的增长,而step数量在batch_size成倍增长后会成倍下降,因此在batch_size增加后总的训练时间将显著地减少

进程虚拟内存使用情况实验结果

model offload_param micro-batches max-seq-len 初始化完成时 训练开始后 初始化完成cpu内存占用
opt-13b none 2 256 97.2 102.5 231
opt-13b CPU 2 256 86.2 92.1 280
opt-6.7b none 8 512 78.5 85.1 150
opt-6.7b CPU 8 512 72.4 79.2 177
opt-1.3b none 2 256 54.5 60.7 58
opt-1.3b CPU 2 256 52.9 60.5 67
opt-1.3b none 8 512 58.1 79.2 70.6

实验分析: