Hi,
Thanks for reaching out. We noticed this new feature in torch 1.8 and we are working on it. The difficulty is keeping it consistent with the current APIs, since the hook function is required to return an async Future handle. We will keep you updated.
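For reference, the built-in allreduce hook in 1.8 has roughly this shape (a simplified sketch, not GRACE-specific): the hook receives a GradBucket, kicks off an async collective, and must return a Future whose value is the list of reduced tensors that DDP copies back into the bucket.

def allreduce_hook(process_group, bucket) -> torch.futures.Future:
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    tensor = bucket.get_tensors()[0]  # flattened 1-D gradient bucket
    fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()

    def average(fut):
        # DDP copies the returned tensor(s) back into the bucket.
        return [fut.value()[0].div_(group_to_use.size())]

    return fut.then(average)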
Hi Hang, that's great, very good to hear. I am now working on migrating DGC to DDP. The problem is that grace's DGC implementation requires two variables to be passed around, indices and values, but the Future can only carry one tensor for all_reduce. I assume that only the values need to go through all_reduce to get the mean value, so I put the values into the Future and the indices into the state, updating them after every compress and reading them before decompress. Is that the right way to do it?
DGC doesn't support all_reduce for the communication because the indices on each worker are different. To follow the original comm_hook logic, you need to first concatenate the values and indices into one single tensor, and then use all_gather in place of the all_reduce, as shown in this example. The decompress function also needs to be modified according to these inputs.
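Roughly along these lines (a minimal sketch of the packing idea only; values, indices, and world_size are assumed to be in scope, and for the moment every rank is assumed to select the same number of elements):

# Sketch: pack (values, indices) into one tensor and all_gather it across ranks.
packed = torch.cat((values, indices.to(values.dtype)), dim=0)
gathered = [torch.empty_like(packed) for _ in range(world_size)]
fut = dist.all_gather(gathered, packed, async_op=True).get_future()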
Hi Hang,
Thanks for the info, but I do not understand why dist.all_reduce should be replaced with dist.all_gather. I thought DGC would sparsify the gradient and update it by indices on each node after the all_reduce. Also, after DGC compression each node ends up with a values/indices tensor of a different size, which cannot be all_gathered into a tensor list, since all_gather requires every tensor to have the same number of elements.
dist.all_reduce simply sums up the dense inputs across the nodes. It doesn't support the value-index-pair sparse tensor format. To perform an allreduce for sparse tensors, you need to use allgather to collect all the sparse tensors, cast them into dense format locally, and then sum them up.
In GRACE, we have developed an allgather for tensors with different lengths; please check here.
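To illustrate the idea (a minimal sketch only; gathered_values, gathered_indices, numel, and device are assumed to be in scope):

# Sketch: a sparse "allreduce" = allgather the (values, indices) pairs,
# turn each rank's contribution back into dense format locally, then sum.
dense_sum = torch.zeros(numel, dtype=torch.float32, device=device)
for rank_values, rank_indices in zip(gathered_values, gathered_indices):
    dense = torch.zeros(numel, dtype=torch.float32, device=device)
    dense.scatter_(0, rank_indices.long(), rank_values)
    dense_sum += dense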
Hi @hangxu0304, thanks for the suggestion. After encountering several bugs in PyTorch, I finally got it running, but the performance is poor (about 5x slower). I hope I got the code right, but it may not be the optimal solution. I will post my solution here in the hope that it helps your later adaptation work, and I would appreciate any performance optimization suggestions.
import torch
from torch import distributed as dist
import logging
import sys
from abc import ABC, abstractmethod
from grace_dl.dist.memory.dgc import DgcMemory
from grace_dl.dist import Compressor, Memory, debug
def allgather(tensors, world_size):
    # Gather tensors whose sizes differ across ranks: first all_gather the sizes,
    # then pad each local tensor to the max size, all_gather, and trim back.
    local_sizes = torch.tensor([t.numel() for t in tensors], device=tensors[0].device)
    gathered_sizes = [torch.empty_like(local_sizes) for _ in range(world_size)]
    dist.all_gather(gathered_sizes, local_sizes)  # tensor of tensor sizes per rank

    tensors_gathered = []
    for tensor, sizes in zip(tensors, zip(*gathered_sizes)):
        local_size = tensor.numel()
        max_size = max(sizes)
        gathered = []
        for _ in range(world_size):
            padded = torch.empty(max_size, dtype=tensor.dtype, layout=tensor.layout, device=tensor.device)
            gathered.append(padded)
        if local_size != max_size:
            padding = torch.empty(max_size - local_size, dtype=tensor.dtype, layout=tensor.layout,
                                  device=tensor.device)
            tensor = torch.cat((tensor, padding), dim=0)
        dist.all_gather(gathered, tensor)

        # Trim each gathered tensor back to its true size.
        data_list = []
        for size, tensor_gathered in zip(sizes, gathered):
            data_list.append(tensor_gathered[:size])
        tensors_gathered.append(data_list)
    return tensors_gathered
class DgcCompressor(Compressor):

    def __init__(self, compress_ratio=0.3, world_size=1, momentum=0.9, gradient_clipping=False):
        super().__init__(memory=DgcMemory(momentum, gradient_clipping, world_size), tensors_size_are_same=False)
        self.compress_ratio = compress_ratio

    def compress(self, tensor, name):
        shape = tensor.size()
        tensor = tensor.flatten()
        numel = tensor.numel()

        # Estimate the top-k threshold on a 1% random sample of the gradient.
        sample_shape = [max(1, int(numel * 0.01))]
        sample_index = torch.empty(sample_shape).uniform_(0, numel).type(torch.long)
        sample_tensor = tensor[sample_index]

        k = max(1, int(numel * self.compress_ratio * 0.01))
        vals, indices = torch.topk(sample_tensor.abs(), k)
        thr = vals.min()
        mask = tensor.abs() >= thr
        selected = mask.sum()

        # Refine the threshold until the selected count is within [0.7, 1.3] of the target.
        for _ in range(10):
            if selected > 1.3 * numel * self.compress_ratio:
                thr = 1.3 * thr
            elif selected < 0.7 * numel * self.compress_ratio:
                thr = 0.7 * thr
            else:
                break
            mask = tensor.abs() >= thr
            selected = mask.sum()

        indices, = torch.where(mask)
        values = tensor[indices]
        # Concatenate values and indices into one single tensor for all_gather
        # (indices ride along as floats and are cast back to int64 in decompress).
        tensor_compressed = torch.cat((values, indices), 0).reshape(2, -1)
        ctx = shape, mask, numel
        return tensor_compressed, ctx
    def decompress(self, tensor_compressed, ctx):
        # Here tensor_compressed is the allgather result: [values_per_rank, indices_per_rank].
        shape, _, numel = ctx
        vals, idxs = tensor_compressed[0], tensor_compressed[1]

        # Scatter each rank's sparse values back into a dense tensor, then sum them up.
        dense_res = []
        for i, val in enumerate(vals):
            res = torch.zeros(numel, dtype=torch.float32, device="cuda")
            res.scatter_(0, idxs[i].to(torch.int64), val)
            dense_res.append(res)
        dense_tensor = torch.stack(dense_res, dim=0)
        sum_all = torch.sum(dense_tensor, 0)
        return sum_all.view(shape)
class DGCState(object):

    __slots__ = [
        "process_group",
        # The fields below are hyperparameters that can be tuned by the user.
        "use_error_feedback",
        "error_dict",
        "iter",
        "start_iter",
        "indices",
        "warm_start",
        "compress_ratio",
        "sizes",
    ]

    def __init__(
        self,
        process_group,
        use_error_feedback=True,
        warm_start=True,
        random_seed=0,
        start_iter=0,
        compress_ratio=0.1,
    ):
        logging.info(
            "DGC config: "
            "start_iter = {}; use_error_feedback = {}; warm_start = {}; compress_ratio = {}.".format(
                start_iter,
                use_error_feedback,
                warm_start,
                compress_ratio,
            )
        )
        self.process_group = process_group
        self.use_error_feedback = use_error_feedback
        self.warm_start = warm_start
        self.start_iter = start_iter
        self.compress_ratio = compress_ratio
        self.iter = 0

    def maybe_increase_iter(self, bucket):
        # Bucket 0 is the last bucket to be allreduced in an iteration,
        # so only increase `iter` when bucket 0 is processed.
        if bucket.get_index() == 0:
            self.iter += 1

        if self.iter == self.start_iter:
            logging.info(
                "Start to apply DGC after {} iterations.".format(self.iter)
            )
def dgc_compress_hook(state: DGCState, bucket: dist._GradBucket) -> torch.futures.Future:
    process_group = state.process_group
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    # The input tensor is a flattened 1D tensor.
    input_tensor = bucket.get_tensors()[0]
    device = input_tensor.device
    dtype = input_tensor.dtype

    dgc = DgcCompressor(state.compress_ratio)
    compressed_tensor, ctx = dgc.compress(input_tensor, "dgc_c")

    # Step 1: all_gather the per-rank sizes of the compressed (values, indices) tensors.
    local_sizes = torch.tensor([t.numel() for t in compressed_tensor], device=device)
    gathered_sizes = [torch.empty_like(local_sizes) for _ in range(world_size)]
    gather_size_fut = dist.all_gather(gathered_sizes, local_sizes, group=group_to_use, async_op=True).get_future()
    gather_size_fut = torch.futures.collect_all([gather_size_fut])

    def diff_size_all_gather(value_tensor, value_size):
        # all_gather tensors of different lengths by padding the local tensor to the max size.
        local_size = value_tensor.numel()
        max_size = max(value_size)
        gathered = []
        for _ in range(world_size):
            padded = torch.empty(max_size, dtype=value_tensor.dtype, layout=value_tensor.layout,
                                 device=value_tensor.device)
            gathered.append(padded)
        if local_size != max_size:
            padding = torch.empty(max_size - local_size, dtype=value_tensor.dtype, layout=value_tensor.layout,
                                  device=value_tensor.device)
            value_tensor = torch.cat((value_tensor, padding), dim=0)
        return dist.all_gather(gathered, value_tensor, group=group_to_use, async_op=True).get_future()

    def get_value_indices_tensor_and_size(fut):
        # Step 2: once the sizes are known, launch the padded all_gather for values and indices.
        fut_list = fut.wait()
        gather_sizes = fut_list[0].wait()[0]
        gather_sizes = list(zip(*gather_sizes))
        value_tensor, index_tensor, value_size = compressed_tensor[0], compressed_tensor[1], gather_sizes[0]
        state.sizes = value_size
        fut0 = diff_size_all_gather(value_tensor, value_size)
        fut1 = diff_size_all_gather(index_tensor, value_size)
        return torch.futures.collect_all([fut0, fut1])

    def decompress(fut):
        # Step 3: trim the padded results back to their true sizes and decompress into a dense gradient.
        fut_list = fut.wait()
        fut_list = fut_list.value()
        values = fut_list[0].wait()[0]
        indices = fut_list[1].wait()[0]

        tensors_gathered = []
        data_list = []
        for size, tensor_gathered in zip(state.sizes, values):
            data_list.append(tensor_gathered[:size])
        tensors_gathered.append(data_list)

        data_list = []
        for size, tensor_gathered in zip(state.sizes, indices):
            data_list.append(tensor_gathered[:size])
        tensors_gathered.append(data_list)

        decompressed_tensor = bucket.get_tensors()[0]
        decompressed_tensor.copy_(dgc.decompress(tensors_gathered, ctx))
        return [decompressed_tensor]

    return gather_size_fut.then(get_value_indices_tensor_and_size).then(decompress)
Thank you very much. Your implementation is really good. I have the following 2 suggestions for your optimization:
1. In DgcCompressor.decompress() there is a for loop that casts the sparse tensors into dense format, which can be very expensive for large gradients. You may want to use the scatter_add API to gain some speed.
2. Regarding the poor performance, are you comparing against GRACE DGC or against the no-compression baseline? Please also note that gradient compression is not always beneficial, depending on the model architecture, network conditions, and number of nodes. Could you please specify your testing environment?
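For example, something along these lines (an untested sketch, assuming the same vals/idxs layout as in your decompress()):

def decompress(self, tensor_compressed, ctx):
    shape, _, numel = ctx
    vals, idxs = tensor_compressed[0], tensor_compressed[1]
    # Concatenate every rank's values/indices and accumulate them into a single
    # dense buffer with one scatter_add_ call; duplicate indices are summed.
    all_vals = torch.cat(list(vals), dim=0)
    all_idxs = torch.cat(list(idxs), dim=0).to(torch.int64)
    sum_all = torch.zeros(numel, dtype=torch.float32, device=all_vals.device)
    sum_all.scatter_add_(0, all_idxs, all_vals)
    return sum_all.view(shape)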
@FRAOTIAC Could you provide a usage example of your implementation? I would love to use it. Do I simply need to pass the state and compress hook to DDP's register_comm_hook?
If you have optimized the code in the meantime or identified any bugs, I would also appreciate receiving the updated version.
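What I have in mind is roughly the following (an untested sketch, assuming the DGCState and dgc_compress_hook definitions above; net and local_rank are placeholders):

ddp_model = torch.nn.parallel.DistributedDataParallel(net, device_ids=[local_rank])
state = DGCState(process_group=None, compress_ratio=0.1)
ddp_model.register_comm_hook(state, dgc_compress_hook)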
Hi, PyTorch 1.8 has this new hook, torch.nn.parallel.DistributedDataParallel.register_comm_hook(). Any advice on how to integrate GRACE into DDP using the dist examples?