jasperzhong / cs-notes

CS认知体系
6 stars 0 forks source link

learn FairScale #21

Closed jasperzhong closed 3 years ago

jasperzhong commented 3 years ago

https://github.com/facebookresearch/fairscale

FairScale is a PyTorch extension library for high performance and large scale training.

看上去是几个module的并集.

jasperzhong commented 3 years ago

image

文档里这张图有点意思. 不过too slow的solution里面看上去有点牵强,activation checkpoint, zero主要是节约memory的方法,应该是说可以增大batch size,然后可能可以提高速度.

jasperzhong commented 3 years ago

https://fairscale.readthedocs.io/en/latest/deep_dive/oss_sdp_fsdp.html

介绍了ZeRO,分别是ZeRO-1, ZeRO-2, ZeRO-3.

Optimizer State Sharding (OSS)

比如Adam做mixed precision training,还要保存一份FP32的momentum和variance. 这挺大的,而且每个worker都有一份,很多冗余. OSS的idea就是把这些FP32的optimizer states shard到各个worker上去,每个worker all-reduce gradients之后,只需要update自己那一份optimizer state,更新自己那一部分parameters,最后把自己这个shard对应的参数做一个broadcast/allgather.

分析下额外的通信量,假如有N个worker,模型参数量为M,每个worker上的shard为M/N,因OSS的broadcast/allgather操作产生M/N (N - 1)大小的通信量. 原本的all-reduce的通信量是 2 M (N-1) / N. 所以OSS带来1.5x通信开销.

文档说

  1. On a single node, OSS should be always faster than vanilla PyTorch, memory savings will vary depending on the optimizer being used

always faster???are you sure??于是我去跑了下,发现居然是真的....用了自带的benchmark脚本. https://github.com/facebookresearch/fairscale/blob/master/benchmarks/oss.py

脚本默认是RMSprop. RMSprop是带个variance作为optimizer state的. 模型是ResNet101,local batch size 256. 在net-g14上测试. Optimizer Median Throughput (img/s) (rank 0) Peak Memory (MB)
Vanilla SGD 1795.73 +- 11.79 1473.8
OSS + DDP 1903.18 +- 25.52 1209.1
OSS + ShardedDDP 1636.11 +- 7.92 927.6
如果用SGDM呢? Optimizer Median Throughput (img/s) (rank 0) Peak Memory (MB)
Vanilla SGD 1977.22 +- 8.64 1303.0
OSS + DDP 1999.37 +- 14.83 1172.6
OSS + ShardedDDP 1703.99 +- 8.99 890.7

这是为什么??因为update时间比broadcast时间还长吗??啊?震惊. 我得去看下代码....

Optimizer + Gradient State Sharding

就是在OSS基础上,觉得aggregated gradient也是冗余. 在反向传播的时候,worker把gradient reduce到那个shard对应的rank上,而不是做all-reduce. 每个rank update自己的参数和optimizer state. 最后做一个broadcast.

emmm. 这个方法看上去挺好的,不过为啥会慢这么多?还是得看代码实现.

Optimizer + Gradient + Horizontal Model Sharding

也叫做Fully Sharded Data Parallel (FSDP). 到目前为止,只剩下模型参数这一个冗余了. FSDP就是要把模型参数也做shard. FW的时候都要all-gather各层的parameter. 反向的时候还要all-gather各层的parameter(算梯度也需要原参数).

比较奇怪的地方是sync gradient用的是reduce scatter. 直接和ZeRO2一样用reduce不就行了吗?

别看需要两次额外的all-gather,但这个all-gather是可以和计算overlap的.

FSDP还支持offload parameter和gradient. 就是ZeRO-offload的方法了.

这段话引起了我的注意. pointwise optimizer vs non-pointwise optimizer. 有意思!

Results should be identical to DDP with pointwise Optimizers, e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.. However, the sharding will result in slightly different results when using non-pointwise Optimizers, e.g., Adagrad, Adafactor, LAMB, etc.


突然发现有这么一个模块torch.optim._multi_tesnor. 看了一下,就是原来是写个for循环,逐个param做update(weight decay, momentum, ...),现在是全部一起做各个操作.

这个没有faithfully实现ZeRO的idea. ZeRO-1和ZeRO-2是用的reduce-scatter + allgather而不是all-reduce/reduce + allgather.

jasperzhong commented 3 years ago

https://fairscale.readthedocs.io/en/latest/deep_dive/offload.html

OffloadModel

介绍了ZeRO-offload. 具体见 https://github.com/vycezhong/read-papers/issues/193

看文档offload似乎有两种用法:

  1. 一种就是ZeRO-offload,offload的是参数.
  2. 另一外一种似乎是ZeRO activation offload,就是如果使用了activation checkpoint,可以offload checkpoint的activation.
jasperzhong commented 3 years ago

https://fairscale.readthedocs.io/en/latest/deep_dive/adascale.html

AdaScale

今年OSDI '21 best paper Pollux https://github.com/vycezhong/read-papers/issues/180 钦定了AdaScale https://github.com/vycezhong/read-papers/issues/149

AdaScale是一种自动adapt lr的方法,适用于large batch和变batch size的情况. 原论文讨论的是SGD. 但Pollux似乎已经应用于Adam了. 所以文档提了一句:

This technique typically works well for SGD (with and without momentum) The assumption is that you already have a good learning rate schedule that works well for small batch sizes. (AdaScale has not been validated to work effectively with Adam, further research in that direction is needed.)

这篇文档讲的确实不错. 加深了我对AdaScale理解. 仔细参悟这段话

Adascale uses the concept of gain ratio which is intuitively a measure of how much the variance has reduced by averaging N small batch gradients. It is a quantity between 1 and N. In practice, the implementation tracks estimates of the gradient variance and norm-squared which are smoothed using an exponentially-weighted moving average. If T is the number of steps used to train the original small batch size before scaling, Adascale stops training once the accumulated gain ratio is greater than T. As you use more and more GPUs in the training the total steps needed to train decreases, but due to the value of gain ratio between [1, N], the total steps does not linearly decrease as you increase the GPUs. Additional training steps are taken to maintain the model accuracy, when compared with original_total_step/N (i.e. linear scaling). In other words, whenever the gain ratio is less than N, we could not take as large a step as we may have hoped for, and so the total number of iterations ends up being larger than T / N.

这个很有必要看下实现.

jasperzhong commented 3 years ago

https://fairscale.readthedocs.io/en/latest/deep_dive/pipeline_parallelism.html

Pipeline Parallelism

实现的是GPipe. 应该是来自https://github.com/kakaobrain/torchgpipe . 目前已经被整合进上游. https://pytorch.org/docs/stable/pipeline.html

jasperzhong commented 3 years ago

https://fairscale.readthedocs.io/en/latest/deep_dive/activation_checkpointing.html

Enhanced Activation Checkpointing

pytorch本身已经支持了activation checkpointing. fairscale enhance的地方是可以把这个checkpoint的activation offload到CPU,并可以handle forward函数的non-tensor输出(这有啥用?)

这个caveat不错,以前没注意到.

When using BatchNormalization you may need to freeze the calculation of statistics since we run the forward pass twice.

jasperzhong commented 3 years ago

总结一下,fairscale可以说是large-scale training的缉书. 虽然代码基本是抄的,但fairscale可以说是总结了目前学界在large-scale training领域的重要成果.

fairscale.nn.pipe is forked from torchgpipe, Copyright 2019, Kakao Brain, licensed under Apache License. fairscale.nn.model_parallel is forked from Megatron-LM, Copyright 2020, NVIDIA CORPORATION, licensed under Apache License. fairscale.optim.adascale is forked from AdaptDL, Copyright 2020, Petuum, Inc., licensed under Apache License. fairscale.nn.misc.flatten_params_wrapper is forked from PyTorch-Reparam-Module, Copyright 2018, Tongzhou Wang, licensed under MIT License.

jasperzhong commented 3 years ago

对ZeRO方法没什么兴趣. 先run pipeline parallelism. 打算跑一个8-stage的BERT-32.

可惜了. 看样子还不支持inter-node pipeline.

好像直接看pytorch自己的文档更好....

https://pytorch.org/docs/stable/pipeline.html

jasperzhong commented 3 years ago

暂时到此为止. 现在需要去深度了解Megatron-LM. 毕竟NVIDIA是真的用这个框架训练过大模型的!研究这个会更有意义一些.

jasperzhong commented 2 years ago

准备看看MoE实现.

https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/moe/moe_layer.py

data processing

注意这是GPT类型的训练. 输入数据是[T, B]. T是sequence length,B是batch size (how many sequences in a batch).

https://github.com/facebookresearch/fairscale/blob/main/benchmarks/datasets/wikitext2_data.py#L49-L51

具体处理的时候,就是把文档里面所有的tokens拼成一个列表. 取连续的T个tokens作为一个sequence,然后取B个sequences作为一个batch. 所以输入数据维度是[T, B].

因为是autoregressive训练,所以输入是batch[0, T] 标签是batch[1, T+1]. 所以一个hack方法就是取数据的时候使用的是sequence length + 1. 这样在训练的时候,很容易就拆分成了两个sequence length的输入和标签.

所以模型输入的维度是[T, B, H]