Closed jasperzhong closed 3 years ago
I ran the Elastic benchmark and the Static (non-elastic) benchmark. The gap between them is quite large.
https://github.com/horovod/horovod/blob/master/examples/elastic/pytorch/pytorch_synthetic_benchmark_elastic.py https://github.com/horovod/horovod/blob/master/examples/pytorch/pytorch_synthetic_benchmark.py
The optimizer also makes a difference.
Benchmark environment: 4 A100 GPUs, ResNet-50, PyTorch, batch size 32.
Mode | Commit interval | Optimizer | Throughput (samples/s) |
---|---|---|---|
Static | - | SGD | 1592.3 |
Static | - | SGDM | 1527.5 |
Elastic | 1 | SGD | 1270.4 |
Elastic | 2 | SGD | 1403.0 |
Elastic | 4 | SGD | 1496.6 |
Elastic | 8 | SGD | 1545.5 |
Elastic | 1 | SGDM | 1156.4 |
Elastic | 2 | SGDM | 1314.1 |
Elastic | 4 | SGDM | 1413.8 |
Elastic | 8 | SGDM | 1475.7 |
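This shows that the overhead is tied to commit: the larger the commit interval, the closer throughput gets to the Static run. SGDM is also slower than SGD in every configuration.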
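To put numbers on the gap, the relative slowdown against the Static run with the same optimizer can be computed directly from the table. The snippet below is only arithmetic on the reported throughputs; the variable and function names are mine.

```python
# Throughput (samples/s) copied from the table above.
static = {"SGD": 1592.3, "SGDM": 1527.5}
elastic = {
    ("SGD", 1): 1270.4, ("SGD", 2): 1403.0, ("SGD", 4): 1496.6, ("SGD", 8): 1545.5,
    ("SGDM", 1): 1156.4, ("SGDM", 2): 1314.1, ("SGDM", 4): 1413.8, ("SGDM", 8): 1475.7,
}

def slowdown_pct(optimizer, interval):
    """Throughput loss relative to the Static run with the same optimizer, in percent."""
    return 100.0 * (1.0 - elastic[(optimizer, interval)] / static[optimizer])

slowdowns = {key: slowdown_pct(*key) for key in elastic}
for (opt, interval), pct in slowdowns.items():
    print(f"{opt:5s} commit interval {interval}: {pct:.1f}% slower than Static")
```

With SGD, committing every batch costs about 20% of throughput, while a commit interval of 8 brings the loss under 3%.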
According to the documentation, committed data is kept in memory.

> Periodically call state.commit() to backup a copy of your state in memory.

The purpose of commit is to avoid a corrupted state when a failure happens partway through an update.

> This is useful to prevent corrupted state in the event that a worker fails unexpectedly. For example, if training fails in the middle of a parameter update, some gradient updates may have applied while others were still being allreduced. When this happens, a HorovodInternalError will be raised, and all parameters will be restored to the values at the time of the last commit.

To avoid committing too often, the official recommendation is:

> In general, if your hardware is generally reliable, and your orchestration system gives the driver ample warning when a host is scheduled to be removed from the job, then you can safely call state.commit() on a reduced frequency, and call state.check_host_updates() at the end of each batch instead.
The API that Elastic Horovod provides is very simple: adding a few lines to a script is enough to get elastic behavior.
```python
import torch
import torch.nn.functional as F   # needed for F.nll_loss below
import torch.optim as optim       # needed for optim.SGD below
import horovod.torch as hvd

hvd.init()

torch.cuda.set_device(hvd.local_rank())

dataset = ...
model = ...

optimizer = optim.SGD(model.parameters(), lr * hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer)

@hvd.elastic.run
def train(state):
    batch_offset = state.batch
    for state.epoch in range(state.epoch, epochs):
        for state.batch in range(state.batch, batches_per_epoch):
            data, target = get_random_batch()

            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

            # back up the state in memory every batches_per_commit batches
            if state.batch % batches_per_commit == 0:
                state.commit()
        state.batch = 0

def on_state_reset():
    # adjust learning rate on reset
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr * hvd.size()

state = hvd.elastic.TorchState(model, optimizer, batch=0, epoch=0)
state.register_reset_callbacks([on_state_reset])
train(state)
```
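The key piece here is state.
Let's start by looking at the horovod.elastic.TorchState class.
Reading the code, I found that both the model state and the optimizer state are saved in GPU memory rather than CPU memory.
The relevant implementation is as follows:
This should certainly be faster than copying to CPU memory, but I did not expect the overhead to still be this large. It caught me a bit by surprise.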
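The commit/restore behavior described in the documentation can be illustrated with a small stand-in. This is a sketch, not Horovod's code: `InMemoryState` is a hypothetical class, and plain Python lists stand in for GPU tensors. The only thing it is meant to share with the real handler is the idea of keeping a deep copy of the last committed state.

```python
import copy

class InMemoryState:
    """Hypothetical stand-in: keeps a deep copy of the last committed state."""

    def __init__(self, params):
        self.params = params                  # live state, mutated by training
        self._saved = copy.deepcopy(params)   # snapshot taken at construction

    def commit(self):
        # Every commit copies the whole state, so the cost grows with its size.
        self._saved = copy.deepcopy(self.params)

    def restore(self):
        # Roll back to the last commit, discarding any partial update.
        self.params = copy.deepcopy(self._saved)

state = InMemoryState({"weight": [1.0, 2.0], "bias": [0.5]})
state.params["weight"][0] = 10.0   # a completed update
state.commit()
state.params["bias"][0] = -1.0     # an update interrupted by a failure
state.restore()
print(state.params)
```

After the restore, the committed change to `weight` is kept and the uncommitted change to `bias` is gone, which is the guarantee commit is meant to provide.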
So I profiled the system: avg save time = 0.01539 s. commit is called 11 times every 10 iterations, so the total save time is 0.16929 s. The original 10 iterations take 0.804 s; with elastic enabled they take 1.0085 s, a difference of 0.2045 s.
So save time does account for most of the overhead: 82.8%.
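The 82.8% figure follows directly from the profile numbers. The snippet below only redoes that arithmetic; the variable names are mine.

```python
avg_save_time = 0.01539      # seconds per commit, from the profile
commits_per_10_iters = 11    # commits observed per 10 iterations
static_time = 0.804          # seconds per 10 iterations without elastic
elastic_time = 1.0085        # seconds per 10 iterations with elastic

save_time = avg_save_time * commits_per_10_iters   # time spent saving state
overhead = elastic_time - static_time              # total extra time from elastic
save_share = save_time / overhead                  # fraction of overhead due to save

print(f"save time      = {save_time:.5f} s")
print(f"total overhead = {overhead:.4f} s")
print(f"save share     = {save_share:.1%}")
```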
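The save time can also be estimated from bandwidth and data volume. According to NVIDIA's official A100 specs, the A100 has 1.6 TB/s of memory bandwidth. The amount saved is the model size, 98 MB, so by data volume / bandwidth the time should be 6.125e-05 s. How does it end up as high as 0.01539 s? That is about 250x!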
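The estimate and the 250x ratio can be checked in a few lines. This uses only the figures quoted above and treats the estimate as a rough lower bound; the variable names are mine.

```python
bandwidth = 1.6e12     # bytes/s, the A100 memory bandwidth quoted above (1.6 TB/s)
model_bytes = 98e6     # 98 MB of saved model state
measured = 0.01539     # seconds, measured average save time

estimated = model_bytes / bandwidth   # ideal time to move the data once
ratio = measured / estimated          # how far the measurement is from the ideal

print(f"estimated = {estimated:.3e} s, measured / estimated = {ratio:.0f}x")
```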
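I ran a quick benchmark. The code is as follows.
The results are as follows.
So deepcopy of the model is very inefficient! deepcopy of a tensor, on the other hand, is roughly in line with the estimate.
If this part of the overhead could be cut by 100x, its share would drop from 82.8% to 0.828% of the current overhead, which is basically negligible.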
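The second question is: where does the other 17.2% of the overhead come from? It is check_host_updates.
This function checks whether workers have been added or removed. How the notification is actually delivered is not in this part of the code; that requires reading the runner/elastic code.
To summarize, Elastic Horovod reduces performance by about 20% when committing every batch. More than 80% of that is in-memory copy overhead, and the rest comes from synchronizing timestamp information (which requires one broadcast).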
The core logic of Elastic Horovod is in common/elastic.py.
The implementation is quite concise: it handles failures and add/remove events by raising exceptions.
HorovodInternalError requires restoring state; HostsUpdatedInterrupt does not. The former is an error and the latter an interrupt, and the naming reflects that difference.
HorovodInternalError is always raised after a failed MPI operation.
HostsUpdatedInterrupt comes from a worker add/remove.
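The control flow described here can be sketched as a retry loop. This is a simplified illustration under my own assumptions, not the code in common/elastic.py: the two exception classes are local stand-ins that reuse the Horovod names, `ToyState` and `run_elastic` are hypothetical, and re-initialization and reset callbacks are left out.

```python
class HorovodInternalError(RuntimeError):
    """Stand-in for the error raised when a collective operation fails."""

class HostsUpdatedInterrupt(RuntimeError):
    """Stand-in for the interrupt raised when workers are added or removed."""

class ToyState:
    """Hypothetical state object that only counts restore and sync calls."""
    def __init__(self):
        self.restores = 0
        self.syncs = 0

    def restore(self):
        self.restores += 1

    def sync(self):
        self.syncs += 1

def run_elastic(func, state):
    """Retry func until it returns; restore state only after an internal error."""
    while True:
        state.sync()                # bring all workers to the same state
        try:
            return func(state)
        except HorovodInternalError:
            state.restore()         # roll back to the last commit
        except HostsUpdatedInterrupt:
            pass                    # state is still consistent, nothing to roll back

# Simulate one failure, then one host update, then a clean run.
events = [HorovodInternalError(), HostsUpdatedInterrupt()]

def train(state):
    if events:
        raise events.pop(0)
    return "done"

toy = ToyState()
result = run_elastic(train, toy)
print(result, toy.restores, toy.syncs)
```

Only the failure triggers a restore; the host update just causes another pass through the loop.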
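Worth noting is the notification_manager, implemented in runner/elastic/worker.py. It appears to register a service that receives HostsUpdatedRequest notifications, while the logic that sends those notifications lives in the Elastic Driver.
As can be seen, every worker registers a WorkerNotificationClient.
Elastic Horovod requires a host discovery script that detects whether the set of hosts has changed. On an add/remove, the client is notified with a HostsUpdatedRequest containing (timestamp, update_res).
timestamp is the wall-clock time (truncated to an integer), and update_res is the bitwise OR of the cases no_update (0) / remove (1) / add (2).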
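Because the three values are 0, 1 and 2, OR-ing them works like a set of bit flags. A minimal sketch, assuming hypothetical names for the flags (only the numeric values come from the text above):

```python
from enum import IntFlag

class UpdateResult(IntFlag):
    """Hypothetical names for the values listed above."""
    NO_UPDATE = 0
    REMOVE = 1
    ADD = 2

# Combine the results of several notifications with bitwise OR.
combined = UpdateResult.NO_UPDATE
for update in (UpdateResult.REMOVE, UpdateResult.ADD):
    combined |= update

print(int(combined))
```

A combined value therefore records whether any host was removed, any host was added, or both.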
Now the check_host_updates logic makes sense: if a received message carries a newer timestamp, it raises HostsUpdatedInterrupt.
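The timestamp check can be sketched for a single worker as follows. This is an illustration under my own assumptions, not the Horovod implementation: `ToyWorkerState` is hypothetical, the exception is a local stand-in, and the cross-worker broadcast mentioned earlier is omitted.

```python
import queue

class HostsUpdatedInterrupt(RuntimeError):
    """Stand-in for the interrupt raised when the host set changes."""

class ToyWorkerState:
    """Hypothetical single-worker model of the timestamp check."""

    def __init__(self):
        self.messages = queue.Queue()   # pending (timestamp, update_res) notifications
        self.last_updated = 0           # newest timestamp already handled

    def on_hosts_updated(self, timestamp, update_res):
        self.messages.put((timestamp, update_res))

    def check_host_updates(self):
        newest = self.last_updated
        combined = 0
        while not self.messages.empty():
            timestamp, update_res = self.messages.get()
            if timestamp > newest:      # ignore stale notifications
                newest = timestamp
                combined |= update_res
        # A real multi-worker system also has to agree on this result across
        # workers before raising; that synchronization step is omitted here.
        if newest > self.last_updated:
            self.last_updated = newest
            raise HostsUpdatedInterrupt(combined)

state = ToyWorkerState()
state.check_host_updates()              # no messages: returns quietly
state.on_hosts_updated(100, 2)          # a host was added at t=100
try:
    state.check_host_updates()
    interrupted = False
except HostsUpdatedInterrupt:
    interrupted = True
print(interrupted, state.last_updated)
```

A later notification with an older timestamp is drained from the queue without raising, since it cannot describe a newer host set.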
https://horovod.readthedocs.io/en/stable/elastic_include.html