sands-lab / grace

GRACE - GRAdient ComprEssion for distributed deep learning
https://sands.kaust.edu.sa/project/grace/
BSD 2-Clause "Simplified" License

Seeking suggestions for embedding into ddp #17

Closed FRAOTIAC closed 3 years ago

FRAOTIAC commented 3 years ago

Hi, PyTorch 1.8 has a new hook, torch.nn.parallel.DistributedDataParallel.register_comm_hook(). Any advice on how to integrate GRACE into DDP using the dist examples?

hangxu0304 commented 3 years ago

Hi,

Thanks for reaching out. We have noticed this new feature in torch 1.8 and we are working on it. The difficulty is staying consistent with our current APIs, since this hook function must return an async Future handle. We will keep you updated.
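
For reference, the hook contract looks roughly like this (a minimal sketch modeled on the built-in allreduce hook pattern in torch 1.8, not GRACE code):

    import torch
    import torch.distributed as dist

    # Sketch of the torch 1.8 comm hook contract: the hook receives a state
    # object and a GradBucket and must return a Future that resolves to the
    # reduced tensor(s) for that bucket.
    def allreduce_hook(state, bucket: dist._GradBucket) -> torch.futures.Future:
        tensor = bucket.get_tensors()[0]
        fut = dist.all_reduce(tensor, async_op=True).get_future()

        def average(fut):
            # fut.value() holds the allreduced tensors; average by world size.
            return [fut.value()[0] / dist.get_world_size()]

        return fut.then(average)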

FRAOTIAC commented 3 years ago

Hi Hang, that's great, very good to hear. I am now working on migrating DGC to DDP. The problem is that GRACE's DGC implementation needs two variables passed around, indices and values, but the Future can only carry one tensor into all_reduce. I assume that only the values need to go through all_reduce to get the mean, so I put the values into the Future and the indices into the state, updating them after every compress and reading them back before decompress. Is that the right way to do it? The sketch below shows roughly what I mean.
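
    # Sketch of the idea: all_reduce only the values through the Future and
    # stash the indices in the hook state. compress/decompress are placeholders
    # for the GRACE DGC calls, not the real API.
    def dgc_values_only_hook(state, bucket):
        tensor = bucket.get_tensors()[0]
        values, indices = compress(tensor)   # placeholder
        state.indices = indices              # update state after every compress
        fut = dist.all_reduce(values, async_op=True).get_future()

        def rebuild(fut):
            mean_values = fut.value()[0] / dist.get_world_size()
            # read the indices back from state before decompress
            return [decompress(mean_values, state.indices, tensor)]  # placeholder

        return fut.then(rebuild)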

hangxu0304 commented 3 years ago

DGC doesn't support all_reduce for the communication, because the indices on each worker are different. To follow the original comm_hook logic, you need to first concat the values and indices into one single tensor, and then replace the all_reduce with all_gather, as shown in this example. The decompress function also needs to be modified according to the inputs.
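
The packing step would look roughly like this (a sketch; note that the concat casts the long indices to the values' float dtype, so they must be restored with .long() afterwards):

    # Pack (values, indices) into a single tensor so the hook ships one tensor
    # per bucket; unpack on the receiving side.
    packed = torch.cat((values, indices.to(values.dtype)), dim=0).reshape(2, -1)
    values_out, indices_out = packed[0], packed[1].long()
    # Caveat: float32 represents integers exactly only up to 2**24, so very
    # large indices can lose precision in this round trip.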

FRAOTIAC commented 3 years ago

Hi Hang, thanks for the info, but I do not understand why we should replace dist.all_reduce with dist.all_gather. I thought DGC sparsifies the gradient and, after the all_reduce, updates the gradient by indices on each node. Also, after DGC compression each node has a values/indices tensor of a different size, which cannot be all_gathered into a list of tensors, because all_gather requires all tensors to have the same number of elements.

hangxu0304 commented 3 years ago

dist.all_reduce simply sums up the dense inputs across the nodes. It doesn't support the value-index pair sparse tensor format. To perform an allreduce on sparse tensors, you need to use allgather to collect all the sparse tensors, cast them into dense format locally, and then sum them up. In GRACE, we have implemented an allgather for tensors of different lengths; please check here.
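
The local "cast to dense, then sum" step is roughly the following (an illustrative sketch, not the GRACE code):

    import torch

    # Illustrative sketch: given per-rank value/index lists collected by
    # allgather, scatter each rank's values into a dense buffer and accumulate.
    def sparse_sum(values_per_rank, indices_per_rank, numel, device):
        total = torch.zeros(numel, device=device)
        for vals, idxs in zip(values_per_rank, indices_per_rank):
            total.scatter_add_(0, idxs.long(), vals)
        return total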

FRAOTIAC commented 3 years ago

Hi @hangxu0304, thanks for the suggestion. After encountering several bugs in PyTorch, I finally got it running, but the performance is poor (about 5x slower). I hope I got the code right, but it may not be the optimal solution. I will post my solution here in the hope that it helps your later adaptation work, and I would appreciate any performance optimization suggestions.

import torch
from torch import distributed as dist
import logging

from grace_dl.dist import Compressor
from grace_dl.dist.memory.dgc import DgcMemory

def allgather(tensors, world_size):
    # Exchange per-rank element counts first, so every rank knows how long the
    # padded tensors from the other ranks really are.
    local_sizes = torch.tensor([t.numel() for t in tensors], device=tensors[0].device)
    gathered_sizes = [torch.empty_like(local_sizes) for _ in range(world_size)]
    dist.all_gather(gathered_sizes, local_sizes)  # sizes of the tensors on each rank
    tensors_gathered = []

    for tensor, sizes in zip(tensors, zip(*gathered_sizes)):
        local_size = tensor.numel()
        max_size = max(sizes)
        # Pad every rank's tensor to the longest length so all_gather can run;
        # the padding is trimmed off again below.
        gathered = [torch.empty(max_size, dtype=tensor.dtype, layout=tensor.layout, device=tensor.device)
                    for _ in range(world_size)]
        if local_size != max_size:
            padding = torch.empty(max_size - local_size, dtype=tensor.dtype, layout=tensor.layout,
                                  device=tensor.device)
            tensor = torch.cat((tensor, padding), dim=0)
        dist.all_gather(gathered, tensor)

        # Trim each gathered tensor back to its true (unpadded) length.
        data_list = [tensor_gathered[:size] for size, tensor_gathered in zip(sizes, gathered)]
        tensors_gathered.append(data_list)

    return tensors_gathered

class DgcCompressor(Compressor):

    def __init__(self, compress_ratio=0.3, world_size=1, momentum=0.9, gradient_clipping=False):
        super().__init__(memory=DgcMemory(momentum, gradient_clipping, world_size), tensors_size_are_same=False)
        self.compress_ratio = compress_ratio

    def compress(self, tensor, name):
        shape = tensor.size()
        tensor = tensor.flatten()
        numel = tensor.numel()

        # Estimate the top-k threshold from a 1% random sample of the tensor.
        sample_shape = [max(1, int(numel * 0.01))]
        sample_index = torch.empty(sample_shape, device=tensor.device).uniform_(0, numel).type(torch.long)
        sample_tensor = tensor[sample_index]

        k = max(1, int(numel * self.compress_ratio * 0.01))
        vals, indices = torch.topk(sample_tensor.abs(), k)

        thr = vals.min()
        mask = tensor.abs() >= thr
        selected = mask.sum()

        # Refine the threshold until the selected count is within +/-30% of the
        # target, with at most 10 adjustment steps.
        for _ in range(10):
            if selected > 1.3 * numel * self.compress_ratio:
                thr = 1.3 * thr
            elif selected < 0.7 * numel * self.compress_ratio:
                thr = 0.7 * thr
            else:
                break
            mask = tensor.abs() >= thr
            selected = mask.sum()

        indices, = torch.where(mask)
        values = tensor[indices]
        # Concat values and indices into one single tensor for all_gather; the
        # long indices are cast to the values' float dtype here and restored
        # with .to(torch.int64) in decompress.
        tensor_compressed = torch.cat((values, indices.to(values.dtype)), 0).reshape(2, -1)
        ctx = shape, mask, numel
        return tensor_compressed, ctx

    def decompress(self, tensor_compressed, ctx):
        shape, _, numel = ctx
        # tensor_compressed is [values_per_rank, indices_per_rank]; scatter each
        # rank's values into a dense buffer, then sum the buffers up.
        vals, idxs = tensor_compressed[0], tensor_compressed[1]
        dense_res = []
        for val, idx in zip(vals, idxs):
            res = torch.zeros(numel, dtype=val.dtype, device=val.device)
            res.scatter_(0, idx.to(torch.int64), val)
            dense_res.append(res)

        sum_all = torch.stack(dense_res, dim=0).sum(0)
        return sum_all.view(shape)

class DGCState(object):
    __slots__ = [
        "process_group",
        # Hyperparameters that should be tuned by the user.
        "use_error_feedback",
        "warm_start",
        "compress_ratio",
        # Internal bookkeeping.
        "error_dict",
        "iter",
        "start_iter",
        "indices",
        "sizes",
    ]

    def __init__(
            self,
            process_group,
            use_error_feedback=True,
            warm_start=True,
            random_seed=0,
            start_iter=0,
            compress_ratio=0.1,
    ):
        logging.info(
            "DGC config: start_iter = {}; use_error_feedback = {}; "
            "warm_start = {}; compress_ratio = {}.".format(
                start_iter,
                use_error_feedback,
                warm_start,
                compress_ratio,
            )
        )
        self.process_group = process_group
        self.use_error_feedback = use_error_feedback
        self.warm_start = warm_start
        self.start_iter = start_iter
        self.compress_ratio = compress_ratio
        self.iter = 0

    def maybe_increase_iter(self, bucket):
        # Bucket 0 is the last bucket to be allreduced in an iteration, so only
        # increase `iter` when bucket 0 is processed.
        if bucket.get_index() == 0:
            self.iter += 1

        if self.iter == self.start_iter:
            logging.info(
                "Start to apply DGC after {} iterations.".format(self.iter)
            )

def dgc_compress_hook(state: DGCState, bucket: dist._GradBucket) -> torch.futures.Future:
    process_group = state.process_group
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()
    # The input tensor is a flattened 1D tensor.
    input_tensor = bucket.get_tensors()[0]
    device = input_tensor.device
    dtype = input_tensor.dtype
    # Note: this constructs a new compressor (and its DgcMemory) on every call.
    dgc = DgcCompressor(state.compress_ratio)
    compressed_tensor, ctx = dgc.compress(input_tensor, "dgc_c")

    local_sizes = torch.tensor([t.numel() for t in compressed_tensor], device=device)
    gathered_sizes = [torch.empty_like(local_sizes) for _ in range(world_size)]

    gather_size_fut = dist.all_gather(gathered_sizes, local_sizes, group=group_to_use, async_op=True).get_future()
    gather_size_fut = torch.futures.collect_all([gather_size_fut])

    def diff_size_all_gather(value_tensor, sizes):
        # Pad the local tensor to the longest length across ranks so that
        # all_gather (which needs equal-sized tensors) can run asynchronously.
        local_size = value_tensor.numel()
        max_size = max(sizes)
        gathered = [torch.empty(max_size, dtype=value_tensor.dtype, layout=value_tensor.layout,
                                device=value_tensor.device)
                    for _ in range(world_size)]
        if local_size != max_size:
            padding = torch.empty(max_size - local_size, dtype=value_tensor.dtype, layout=value_tensor.layout,
                                  device=value_tensor.device)
            value_tensor = torch.cat((value_tensor, padding), dim=0)

        return dist.all_gather(gathered, value_tensor, group=group_to_use, async_op=True).get_future()

    def get_value_indices_tensor_and_size(fut):
        fut_list = fut.wait()
        gather_sizes = fut_list[0].wait()[0]
        gather_sizes = list(zip(*gather_sizes))

        # The values and indices rows have the same length, so one size list
        # covers both gathers.
        value_tensor, index_tensor, value_size = compressed_tensor[0], compressed_tensor[1], gather_sizes[0]
        state.sizes = value_size
        fut0 = diff_size_all_gather(value_tensor, value_size)
        fut1 = diff_size_all_gather(index_tensor, value_size)
        return torch.futures.collect_all([fut0, fut1])

    def decompress(fut):
        # The previous callback returned a Future (collect_all), which arrives
        # here as the value of `fut`; unwrap it to get the two gather results.
        fut_all = fut.wait()
        fut_list = fut_all.value()

        values = fut_list[0].wait()[0]
        indices = fut_list[1].wait()[0]

        # Trim the padded tensors back to their true per-rank lengths.
        tensors_gathered = [
            [t[:size] for size, t in zip(state.sizes, values)],
            [t[:size] for size, t in zip(state.sizes, indices)],
        ]

        decompressed_tensor = bucket.get_tensors()[0]
        decompressed_tensor.copy_(dgc.decompress(tensors_gathered, ctx))
        return [decompressed_tensor]

    return gather_size_fut.then(get_value_indices_tensor_and_size).then(decompress)

hangxu0304 commented 3 years ago

Thank you very much. Your implementation is really good. I have the following suggestions for your optimization:

Regarding the poor performance, are you comparing against GRACE DGC or against the no-compression baseline? Also, please note that gradient compression is not always beneficial; it depends on the model architecture, the network conditions, and the number of nodes. Could you please specify your testing environment?

BraSDon commented 10 months ago

@FRAOTIAC Could you provide a usage example of your implementation? I would love to use it. Do I simply need to pass the state and compress hook to DDP's register_comm_hook?

If you have optimized the code in the meantime or identified any bugs, I would also appreciate receiving the updated version.
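
For reference, based on the code above I would expect the wiring to look roughly like this (an untested sketch; `module` and `local_rank` are placeholders):

    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    # Untested sketch, assuming the DGCState and dgc_compress_hook definitions
    # posted above; `module` and `local_rank` are placeholders.
    dist.init_process_group(backend="nccl")
    model = DDP(module, device_ids=[local_rank])

    state = DGCState(process_group=None, compress_ratio=0.1)
    model.register_comm_hook(state, dgc_compress_hook)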