sands-lab / grace

GRACE - GRAdient ComprEssion for distributed deep learning
https://sands.kaust.edu.sa/project/grace/
BSD 2-Clause "Simplified" License

DGC GPU usage in GRACE #22

Closed Snoeprol closed 2 years ago

Snoeprol commented 2 years ago

Hi,

Thanks for making this awesome library; it really makes things easy to understand. However, I have run into a performance issue: when using DGC, the GPU utilization (as reported by WANDB) is only about 20%, whereas some other techniques I have implemented reach over 75%. As a consequence, training times increase many-fold, which is obviously undesirable. I was wondering what the cause of this could be.

[Screenshot: WANDB GPU utilization chart]

I will also take a deeper dive into this myself, as I'm currently comparing DGC to other methods. If anybody has tips on where the possible bottlenecks are, please let me know.

The batch size is definitely large enough, since other methods with a batch size of 128 reach high utilization. Could it be that some tensor is located on the CPU?
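
For reference, one quick way to check that hypothesis is to list every parameter, gradient, and optimizer-state tensor that is not on the GPU (just a sketch; model and optimizer stand in for whatever you are training with):

import torch

def find_cpu_tensors(model, optimizer):
    """List any parameter, gradient, or optimizer-state tensor on the CPU."""
    on_cpu = []
    for name, p in model.named_parameters():
        if p.device.type == 'cpu':
            on_cpu.append(('param ' + name, p.device))
        if p.grad is not None and p.grad.device.type == 'cpu':
            on_cpu.append(('grad ' + name, p.grad.device))
    for state in optimizer.state.values():
        for key, value in state.items():
            if torch.is_tensor(value) and value.device.type == 'cpu':
                on_cpu.append(('state ' + key, value.device))
    return on_cpu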

Best,

Mario

hangxu0304 commented 2 years ago

The main reason for the poor performance of DGC might be the for loop in the compression phase. You can also compare it with the official DGC implementation.
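
To illustrate the kind of overhead meant here, a rough sketch (not the actual GRACE code): selecting the top-k entries chunk by chunk in a Python loop launches many tiny kernels, while a single call over the flattened gradient keeps the GPU busy.

import torch

def topk_looped(grad, k, chunk=4096):
    # Slow pattern: one small top-k per chunk, dominated by Python/launch overhead.
    grad = grad.flatten()
    picked = []
    for start in range(0, grad.numel(), chunk):
        part = grad[start:start + chunk].abs()
        kk = max(1, int(k * part.numel() / grad.numel()))
        picked.append(torch.topk(part, kk).indices + start)
    return torch.cat(picked)

def topk_vectorized(grad, k):
    # Fast pattern: a single top-k over the flattened tensor.
    return torch.topk(grad.abs().flatten(), k).indices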

Snoeprol commented 2 years ago

> The main reason for the poor performance of DGC might be the for loop in the compression phase. You can also compare it with the official DGC implementation.

Thanks for your quick reply. I will investigate it.

Snoeprol commented 2 years ago

I have written a class for DGC which speeds up the training significantly. I will release it when my project is done :)

Snoeprol commented 2 years ago

import torch
import horovod.torch as hvd
import wandb
from scipy import stats
from torch.optim.optimizer import Optimizer, required


class DGC(Optimizer):
    r"""Implements deep gradient compression (optionally with momentum).
    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.
    Args:
        model: model whose trainable parameters are optimized
        lr (float): learning rate
        threshold_method: threshold selection method (stored but not used in
            the code shown here)
        compression (float): target compression factor; sets the initial
            K = numel / (compression * world_size)
        momentum (float, optional): momentum factor (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        numel (int, optional): total number of trainable gradient elements
        decay (float, optional): factor by which the gradient residual is
            decayed after every step (default: 1)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        threshold_momentum (float, optional): stored threshold momentum
            (default: 0.9; not used in the code shown here)
        threshold (float, optional): stored threshold value (default: 0.5;
            not used in the code shown here)
    Example:
        >>> optimizer = DGC(model, lr=0.1, threshold_method=None,
        ...                 compression=1000, momentum=0.9, numel=numel)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.set_grad(epoch)
        >>> optimizer.step()
    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et. al. and implementations in some other frameworks.
        Considering the specific case of Momentum, the update can be written as
        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}
        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the 
        parameters, gradient, velocity, and momentum respectively.
        This is in contrast to Sutskever et. al. and
        other frameworks which employ an update of the form
        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}
        The Nesterov version is analogously modified.
    """

    def __init__(self, model, lr=required, threshold_method=required, compression=required,
                momentum=0, dampening=0, numel=None, decay=1,
                weight_decay=0, nesterov=False, threshold_momentum=0.9,
                threshold=0.5):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov, numel=numel)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(DGC, self).__init__(model.parameters(), defaults)

        self.momentum = momentum
        self.model = model
        self.names_shapes = self._get_names_shapes()
        self.device = next(model.parameters()).device
        self.numel = numel
        self.decay = decay
        self.threshold_method = threshold_method
        self.actual_elements_sent = 0
        self.threshold_momentum = threshold_momentum
        self.threshold = threshold
        self.last_res_norm = None
        self.compression = compression
        self.nesterov = nesterov
        # Referenced in _obtain_mask; False selects the sampled-threshold path
        self.topk = False

        self.id = hvd.rank()
        self.size = hvd.size()

        self.K = numel/(compression * self.size)
        # Set residual and temporary vectors
        self.residual_grad = torch.zeros((numel), dtype=torch.float32, device=self.device)
        self.residual_velo = torch.zeros((numel), dtype=torch.float32, device=self.device)
        self.aggregate = torch.zeros((numel), dtype=torch.float32, device=self.device)
        self.sent_res = torch.zeros((numel), dtype=torch.float32, device=self.device)
        self.sent_vel = torch.zeros((numel), dtype=torch.float32, device=self.device)
        self.grad_tens = torch.zeros((numel), dtype=torch.float32, device=self.device)

        # Per-epoch warm-up schedule: fraction of gradient entries to send
        shared = torch.tensor([0.25, 0.0625, 0.025725, 0.03, 0.01, 0.001])
        self.shared = shared.to(self.device)

    def __setstate__(self, state):
        super(DGC, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def set_grad(self, epoch):
        self.epoch = epoch
        # Pipeline: fold this round's gradient into the residuals, pick the
        # entries to send, allreduce them across workers, write the result
        # back into the model's .grad fields, and decay the remaining residual.
        self._update_residual()
        self._obtain_mask()
        self._save_sent_vals()
        self._aggregate_res()
        self._remove_sent_vals()
        self._insert_grad()
        self._decay_residual()

    def _decay_residual(self):
        self.residual_grad *= self.decay

    def _remove_sent_vals(self):
        # Subtract the entries that were just sent from the residuals
        # (TODO: could be done in-place on the masked indices for speed)
        self.residual_grad -= self.sent_res
        self.residual_velo -= self.sent_vel

    def _aggregate_res(self):
        # Sum the (densely stored) sparse velocity residuals across workers
        self.agg_res = hvd.allreduce(self.sent_vel, op=hvd.Sum)

    def _obtain_mask(self):
        """
        Select the indices of the residual velocity whose magnitude exceeds
        the current threshold and store them in ``self.mask``. The number of
        selected entries follows the per-epoch warm-up fractions in
        ``self.shared``.
        """
        abs_tens = torch.abs(self.residual_velo)

        # Warm-up schedule: fraction of entries to keep this epoch
        if self.epoch < 5:
            self.K = self.shared[self.epoch - 1] * self.numel
        else:
            self.K = self.shared[-1] * self.numel

        if not self.topk:
            # Sampling-based threshold selection (see _compress)
            self.mask = self._compress(abs_tens)
        else:
            print('using torch topk')
            self.mask = torch.topk(abs_tens, int(self.K))[1].to(torch.int32)
        wandb.log({'compression': self.numel / len(self.mask)})

    def _update_residual(self):
        tensor_grad = self._list_grad_to_array_grad()
        # Clip it
        tensor_squ_sum = torch.sum(tensor_grad * tensor_grad)
        clipping_val = torch.sqrt(hvd.allreduce(tensor_squ_sum))

        tensor_grad = tensor_grad.clamp(-clipping_val, clipping_val)

        # Add it to residual  
        if self.nesterov:
            self.residual_grad = self.momentum * (self.residual_grad + tensor_grad)
            self.residual_velo += (self.residual_grad + tensor_grad) 
        else:     
            self.residual_grad = self.momentum * self.residual_grad + tensor_grad
            self.residual_velo += self.residual_grad

    def _compress(self, tensor):
        """Estimate a top-K threshold from a ~1% random sample and return the
        indices of all entries whose magnitude exceeds it."""
        tensor = tensor.flatten()
        numel = tensor.numel()

        # Draw a random 1% sample of the entries to estimate the threshold
        sample_shape = [max(1, int(numel * 0.01))]
        sample_index = torch.empty(sample_shape, device=tensor.device).uniform_(0, numel).long()

        sample_tensor = tensor[sample_index]

        k = max(1, int(self.K * 0.01))
        vals, indices = torch.topk(sample_tensor, k)

        thr = vals.min()
        mask = tensor.abs() >= thr
        selected = mask.sum()

        # Adjust the threshold until the selected count is within [0.8*K, 1.3*K]
        for _ in range(10):
            if selected > 1.3 * self.K:
                thr = 1.3 * thr
            elif selected < 0.8 * self.K:
                thr = 0.8 * thr
            else:
                break
            mask = tensor >= thr
            selected = mask.sum()

        indices, = torch.where(mask)
        return indices

    def _extract_grad(self) -> list:
        """
        After a backward pass on ``self.model``, collect the gradients of all
        trainable parameters.
        Returns
        -------
        grads : list
            List of ``(name, grad)`` tuples, one per trainable parameter.
        """
        grads = []
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                grads.append((name, param.grad))
        return grads

    def _save_sent_vals(self):
        self.sent_res[:] = 0
        self.sent_vel[:] = 0
        self.sent_res[self.mask] = self.residual_grad[self.mask]
        self.sent_vel[self.mask] = self.residual_velo[self.mask]

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            # Removed code here momentum already happens
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad

                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)
                p.add_(d_p, alpha=-group['lr'])

        return loss

    def _get_names_shapes(self):
        """
        Get the names and shapes of the trainable parameters of ``self.model``.
        Returns
        -------
        names_shapes : list
            List of ``(name, shape)`` tuples, one per trainable parameter.
        """
        names_shapes = []
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                names_shapes.append((name, param.data.shape))
        return names_shapes

    def _list_grad_to_array_grad(self) -> torch.Tensor:
        self.grad_tens[:] = 0
        ugly_gradient = self._extract_grad()
        ind_count = 0
        for sub_grad in ugly_gradient:
            flat_sub_grad = torch.flatten(sub_grad[1])
            sub_grad_len = len(flat_sub_grad)
            self.grad_tens[ind_count: ind_count + sub_grad_len] = flat_sub_grad
            ind_count += sub_grad_len
        assert(len(self.grad_tens) == ind_count)

        return self.grad_tens

    @staticmethod
    def count_grad_length(model):
        """Total number of trainable gradient elements in a model."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    def _compress_no_noise(self):
        """
        Gather the residual-velocity values selected by ``self.mask`` into
        ``self.compressed_grad`` so they can be allreduced efficiently,
        e.g. values ``[1, 2, 3, 4, 5]`` with mask ``[0, 3]`` give ``[1, 4]``.
        """
        self.compressed_grad = self.residual_velo[self.mask]

    def _insert_grad(self):
        """
        Write the aggregated (decompressed) gradient back into the ``.grad``
        fields of ``self.model``'s parameters.
        """
        # Creates the gradient list object
        self._array_grad_to_list_grad()

        idx = 0
        for (_, param) in self.model.named_parameters():
            if param.requires_grad:
                # Detach fixes the backward error
                param.grad = self.grad_list[idx][1].detach()
                idx += 1

    def _array_grad_to_list_grad(self):
        """Split the flat aggregated gradient back into per-parameter tensors."""
        grad_list = []
        cur_idx = 0
        for name, shape in self.names_shapes:
            size = 1
            for axis_shape in shape:
                size *= axis_shape
            reshaped_tens = self.agg_res[cur_idx:cur_idx + size].reshape(shape)
            grad_list.append((name, reshaped_tens))
            cur_idx += size
        self.grad_list = grad_list

    def _gen_threshold_from_normal_distribution(self, p_value, mu, sigma):
        zvalue = stats.norm.ppf( (1-p_value) / 2)
        left_bound = mu + zvalue * sigma
        right_bound = mu - zvalue * sigma
        return left_bound, right_bound

    def _get_top_indices(self, tensor, compression=None, ratio=0.05):
        with torch.no_grad():
            numel = tensor.numel()
            if compression is not None:
                k = max(int(numel / compression), 1)
            else:
                k = max(int(numel * ratio), 1)

            t_norm = tensor.norm(2)

            norm_tensor = tensor / t_norm
            abs_norm_tensor = norm_tensor.abs()

            if torch.__version__ < '1.3.0':
                t_std = torch.std(abs_norm_tensor)
                t_mean = torch.mean(abs_norm_tensor)
            else:
                t_std, t_mean = torch.std_mean(abs_norm_tensor)

            left_thres, right_thres = self._gen_threshold_from_normal_distribution(1 - ratio, float(t_mean), float(t_std))

            loops = 0
            while loops < 3:
                one_indexes = abs_norm_tensor > right_thres
                indices = one_indexes.nonzero().data.squeeze().view(-1)
                if indices.numel() < 2 * k / 3:
                    right_thres *= 0.5
                elif indices.numel() > 4 * k / 3:
                    right_thres *= 1.5
                else:
                    break
                loops += 1
        return indices
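
For completeness, here is a rough sketch of how the class above seems intended to be driven, going by its method names. model, loader, loss_fn, num_epochs and the constructor values are placeholders; threshold_method is required by the signature but unused in the code shown; Horovod and wandb must be initialized because the class calls hvd.rank(), hvd.allreduce() and wandb.log().

import horovod.torch as hvd
import torch
import wandb

hvd.init()
torch.cuda.set_device(hvd.local_rank())   # standard Horovod GPU pinning
wandb.init(project='dgc-test')            # _obtain_mask logs the achieved compression

numel = sum(p.numel() for p in model.parameters() if p.requires_grad)
optimizer = DGC(model, lr=0.1, threshold_method=None, compression=1000,
                momentum=0.9, numel=numel)

for epoch in range(1, num_epochs + 1):   # 1-based: _obtain_mask indexes shared[epoch - 1]
    for data, target in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        optimizer.set_grad(epoch)   # clip, accumulate residuals, select, allreduce, write back
        optimizer.step()            # plain SGD-style update with the restored gradient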