ModelTC / United-Perception


Using direct gradient computation for EFL #41

Open · dongnana777 opened this issue 2 years ago

dongnana777 commented 2 years ago

Have you ever used this loss with deformable-DETR? I am trying to use the loss from your paper in deformable-DETR, but I ran into some problems while reproducing it and would like to ask for your advice; my reproduction code is below. My questions are:

  1. Are the tensor shapes I annotated in the code correct?
  2. If they are, I hit an error when computing the focusing factor (I wrote the error message into the code).
  3. I want to compute the gradient directly; is the way my code computes the gradient correct?
  4. What does ignore_index do in your code? What is the mask it produces meant to filter out?
dongnana777 commented 2 years ago
import torch
import torch.nn.functional as F
from models.loss import _reduce
from models.entropy_loss import GeneralizedCrossEntropyLoss

import torch.distributed as dist
try:
    import spring.linklink as link
except:   # noqa
    link = None

class DistBackend():
    def __init__(self):
        # self.backend = 'linklink'
        self.backend = 'dist'
DIST_BACKEND = DistBackend()

def allreduce(*args, **kwargs):
    if DIST_BACKEND.backend == 'linklink':
        return link.allreduce(*args, **kwargs)
    elif DIST_BACKEND.backend == 'dist':
        return dist.all_reduce(*args, **kwargs)
    else:
        raise NotImplementedError
def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        # target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device)
        # target_classes[idx] = target_classes_o

        target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        # target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
        #                                     dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
        # target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
        # target_classes_onehot = target_classes_onehot[:,:,:-1]

        loss_ce = self.efl(src_logits, target_classes) / num_boxes
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

class EqualizedFocalLoss(GeneralizedCrossEntropyLoss):
    def __init__(self,
                 name='equalized_focal_loss',
                 reduction='mean',
                 loss_weight=1.0,
                 ignore_index=-1,
                 num_classes=1204,
                 focal_gamma=2.0,
                 focal_alpha=0.25,
                 scale_factor=8.0,
                 fpn_levels=5):
        activation_type = 'sigmoid'
        GeneralizedCrossEntropyLoss.__init__(self,
                                             name=name,
                                             reduction=reduction,
                                             loss_weight=loss_weight,
                                             activation_type=activation_type,
                                             ignore_index=ignore_index)

        # cfg for focal loss
        self.focal_gamma = focal_gamma
        self.focal_alpha = focal_alpha

        # ignore bg class and ignore idx
        self.num_classes = num_classes - 1

        # cfg for efl loss
        self.scale_factor = scale_factor
        # initial variables
        self.register_buffer('pos_grad', torch.zeros(self.num_classes))
        self.register_buffer('neg_grad', torch.zeros(self.num_classes))
        self.register_buffer('pos_neg', torch.ones(self.num_classes))

        # grad collect
        self.grad_buffer = []
        self.fpn_levels = fpn_levels

    def forward(self, input, target, reduction, normalizer=None):
        """
        input: [1, 300, 1204]
        target: [1, 300]
        self.pos_neg: [1203]
        self.n_c: 1204
        self.n_i: 300
        """
        self.n_c = input.shape[-1]
        self.input = input.reshape(-1, self.n_c)
        self.target = target.reshape(-1)
        self.n_i, _ = self.input.size()

        def expand_label(pred, gt_classes):
            target = pred.new_zeros(self.n_i, self.n_c + 1)
            target[torch.arange(self.n_i), gt_classes] = 1
            return target[:, 1:]

        expand_target = expand_label(self.input, self.target)
        sample_mask = (self.target != self.ignore_index)

        inputs = self.input[sample_mask]
        targets = expand_target[sample_mask]
        self.cache_mask = sample_mask
        self.cache_target = expand_target

        pred = torch.sigmoid(inputs)
        pred_t = pred * targets + (1 - pred) * (1 - targets)

        # collect gradient
        self.collect_grad(inputs.detach(), targets.detach())

        map_val = 1 - self.pos_neg.detach()
        dy_gamma = self.focal_gamma + self.scale_factor * map_val
        # focusing factor
        ff = dy_gamma.view(1, -1).expand(self.n_i, self.n_c)[sample_mask]
        # RuntimeError: The expanded size of the tensor (1204) must match the existing size (1203) at non-singleton dimension 1.
        # Target sizes: [300, 1204].  Tensor sizes: [1, 1203]

        # weighting factor
        wf = ff / self.focal_gamma

        # ce_loss
        ce_loss = -torch.log(pred_t)
        cls_loss = ce_loss * torch.pow((1 - pred_t), ff.detach()) * wf.detach()

        # to avoid an OOM error
        # torch.cuda.empty_cache()

        if self.focal_alpha >= 0:
            alpha_t = self.focal_alpha * targets + (1 - self.focal_alpha) * (1 - targets)
            cls_loss = alpha_t * cls_loss

        if normalizer is None:
            normalizer = 1.0

        return _reduce(cls_loss, reduction, normalizer=normalizer)

    def collect_grad(self, inputs, targets):
        inputs.requires_grad_(True)

        pred = torch.sigmoid(inputs)
        loss = pred * targets + (1 - pred) * (1 - targets)
        # loss = self.sigmoid_focal_loss(inputs, targets)

        loss.backward(torch.ones_like(loss))
        grad = inputs.grad

        grad = torch.abs(grad)[self.cache_mask]

        # do not collect grad for the objectness branch [:-1]
        pos_grad = torch.sum(grad * targets, dim=0)[:-1]
        neg_grad = torch.sum(grad * (1 - targets), dim=0)[:-1]

        # allreduce(pos_grad)
        # allreduce(neg_grad)

        self.pos_grad += pos_grad
        self.neg_grad += neg_grad
        temp = self.pos_grad[self.pos_grad > 0]
        self.pos_neg = torch.clamp(self.pos_grad / (self.neg_grad + 1e-10), min=0, max=1)

    def sigmoid_focal_loss(self, inputs, targets, alpha: float = 0.25, gamma: float = 2):
        """
        Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
        Args:
            inputs: A float tensor of arbitrary shape.
                    The predictions for each example.
            targets: A float tensor with the same shape as inputs. Stores the binary
                     classification label for each element in inputs
                    (0 for the negative class and 1 for the positive class).
            alpha: (optional) Weighting factor in range (0,1) to balance
                    positive vs negative examples. Default = -1 (no weighting).
            gamma: Exponent of the modulating factor (1 - p_t) to
                   balance easy vs hard examples.
        Returns:
            Loss tensor
        """
        prob = inputs.sigmoid()
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
        p_t = prob * targets + (1 - prob) * (1 - targets)
        loss = ce_loss * ((1 - p_t) ** gamma)

        if alpha >= 0:
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            loss = alpha_t * loss

        # return loss.mean(1).sum()
        return loss.mean(1).sum()
waveboo commented 2 years ago

Hi @dongnana777, thanks for your questions.

  1. Your input shape may be incorrect. Suppose we have (n+1) classes (where 1 is the background). For a softmax-based loss the input dim should be n+1, but for a sigmoid-based loss (e.g., Focal Loss, EFL) the input dim should be n, and the one-hot label of the background class is [0, 0, ..., 0] (all zeros). If you have an objectness or centerness branch, please split it out and compute its loss separately.
  2. The second problem is caused by the first one.
  3. We highly recommend using the hook method to collect the gradient, as in our implementation. Your collect function is wrong somewhere because the loss it differentiates is not the right one (you use the raw sigmoid output rather than EFL to calculate the "EFL gradient"). If you do want to collect the gradient with your own function, you need to check the gradient values carefully and compare them with the values from the hook method to make sure they are correct.
  4. The ignore_index is only used for the RetinaNet detector; it ignores samples whose IoU with the ground truth is between 0.4 and 0.5. For other methods the sample mask is always true.
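A minimal sketch of the label convention in point 1, with toy names and shapes (not code from this repo): the sigmoid head outputs n logits, and the one-hot target row for a background sample is all zeros, which is exactly what the expand_label helper above produces.

import torch

num_classes = 3                          # n foreground classes (toy value)
n_i = 5                                  # number of queries / samples
logits = torch.randn(n_i, num_classes)   # sigmoid head: n logits, no background column

# 0 means background, 1..n are foreground classes (same convention as expand_label)
gt_classes = torch.tensor([0, 1, 3, 0, 2])

# build an (n_i, n+1) one-hot table, then drop the background column,
# so background rows become the all-zero vector [0, 0, ..., 0]
onehot = logits.new_zeros(n_i, num_classes + 1)
onehot[torch.arange(n_i), gt_classes] = 1
onehot = onehot[:, 1:]

print(onehot)   # rows with gt_classes == 0 are all zeros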

dongnana777 commented 2 years ago

Hi, deformable-DETR also uses focal loss, but the input dimension of the loss function is n+1 (1203+1); the background label is [0, 0, ..., 0, ..., 0] and a foreground label is [0, 0, ..., 1, ..., 0]. Besides, since def-detr computes the loss separately for each output, computing the gradient directly is more convenient. For the direct computation, should I use the gradient of focal loss? My code is below; is the part that computes the gradient directly correct?

dongnana777 commented 2 years ago

class EqualizedFocalLoss(GeneralizedCrossEntropyLoss):
  def __init__(self,
               name='equalized_focal_loss',
               reduction='mean',
               loss_weight=1.0,
               ignore_index=-1,
               num_classes=1204,
               focal_gamma=2.0,
               focal_alpha=0.25,
               scale_factor=8.0,
               fpn_levels=5):
      activation_type = 'sigmoid'
      GeneralizedCrossEntropyLoss.__init__(self,
                                           name=name,
                                           reduction=reduction,
                                           loss_weight=loss_weight,
                                           activation_type=activation_type,
                                           ignore_index=ignore_index)

      # cfg for focal loss
      self.focal_gamma = focal_gamma
      self.focal_alpha = focal_alpha

      # ignore bg class and ignore idx
      self.num_classes = num_classes - 1

      # cfg for efl loss
      self.scale_factor = scale_factor
      # initial variables
      self.register_buffer('pos_grad', torch.zeros(self.num_classes))
      self.register_buffer('neg_grad', torch.zeros(self.num_classes))
      self.register_buffer('pos_neg', torch.ones(self.num_classes))

      # grad collect
      self.grad_buffer = []
      self.fpn_levels = fpn_levels

  def forward(self, input, target, reduction, normalizer=None):
      """
      input: [1, 300, 1204]
      target: [1, 300]
      self.pos_neg: [1203]
      self.n_c: 1204
      self.n_i: 300
      """
      self.n_c = input.shape[-1]
      self.input = input.reshape(-1, self.n_c)
      self.target = target.reshape(-1)
      self.n_i, _ = self.input.size()

      def expand_label(pred, gt_classes):
          target = pred.new_zeros(self.n_i, self.n_c + 1)
          target[torch.arange(self.n_i), gt_classes] = 1
          return target[:, 1:]

      expand_target = expand_label(self.input, self.target)
      sample_mask = (self.target != self.ignore_index)

      inputs = self.input[sample_mask]
      targets = expand_target[sample_mask]
      self.cache_mask = sample_mask
      self.cache_target = expand_target

      pred = torch.sigmoid(inputs)
      pred_t = pred * targets + (1 - pred) * (1 - targets)

      # collect gradient
      self.collect_grad(inputs.detach(), targets.detach())

      map_val = 1 - self.pos_neg.detach()
      dy_gamma = self.focal_gamma + self.scale_factor * map_val
      # focusing factor
      ff = dy_gamma.view(1, -1).expand(self.n_i, self.n_c)[sample_mask]
      # RuntimeError: The expanded size of the tensor (1204) must match the existing size (1203) at non-singleton dimension 1.
      # Target sizes: [300, 1204].  Tensor sizes: [1, 1203]

      # weighting factor
      wf = ff / self.focal_gamma

      # ce_loss
      ce_loss = -torch.log(pred_t)
      cls_loss = ce_loss * torch.pow((1 - pred_t), ff.detach()) * wf.detach()

      # to avoid an OOM error
      # torch.cuda.empty_cache()

      if self.focal_alpha >= 0:
          alpha_t = self.focal_alpha * targets + (1 - self.focal_alpha) * (1 - targets)
          cls_loss = alpha_t * cls_loss

      if normalizer is None:
          normalizer = 1.0

      return _reduce(cls_loss, reduction, normalizer=normalizer)

  def collect_grad(self, inputs, targets):
      inputs.requires_grad_(True)

      loss = self.sigmoid_focal_loss(inputs, targets)

      loss.backward(torch.ones_like(loss))
      grad = inputs.grad

      grad = torch.abs(grad)[self.cache_mask]

      # do not collect grad for the objectness branch [:-1]
      pos_grad = torch.sum(grad * targets, dim=0)
      neg_grad = torch.sum(grad * (1 - targets), dim=0)

      # allreduce(pos_grad)
      # allreduce(neg_grad)

      self.pos_grad += pos_grad
      self.neg_grad += neg_grad
      temp = self.pos_grad[self.pos_grad > 0]
      self.pos_neg = torch.clamp(self.pos_grad / (self.neg_grad + 1e-10), min=0, max=1)

  def sigmoid_focal_loss(self, inputs, targets, alpha: float = 0.25, gamma: float = 2):
      """
      Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
      Args:
          inputs: A float tensor of arbitrary shape.
                  The predictions for each example.
          targets: A float tensor with the same shape as inputs. Stores the binary
                   classification label for each element in inputs
                  (0 for the negative class and 1 for the positive class).
          alpha: (optional) Weighting factor in range (0,1) to balance
                  positive vs negative examples. Default = -1 (no weighting).
          gamma: Exponent of the modulating factor (1 - p_t) to
                 balance easy vs hard examples.
      Returns:
          Loss tensor
      """
      prob = inputs.sigmoid()
      ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
      p_t = prob * targets + (1 - prob) * (1 - targets)
      loss = ce_loss * ((1 - p_t) ** gamma)

      if alpha >= 0:
          alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
          loss = alpha_t * loss

      # return loss.mean(1).sum()
      return loss
waveboo commented 2 years ago

Hi @dongnana777,

  1. The official implementation of deformable-DETR only seems to use an n-class focal loss (20 classes for VOC and 91 classes for COCO).
  2. I cannot understand the second point, "since def-detr computes the loss separately for each output, computing the gradient directly is more convenient". If you mean the aux_loss in deformable_detr, the hook method seems very convenient there too. Actually, RetinaNet, FCOS, and ATSS all have 5 sub-networks (corresponding to the 5 FPN levels) for the final classifier, and we use the hook method to collect their gradients without much difficulty. Meanwhile, you adopt the focal loss function to calculate the EFL gradient, which is also incorrect.
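For reference, a minimal, self-contained sketch of the hook idea recommended above, using Tensor.register_hook and a plain BCE-with-logits loss as a stand-in for EFL (all names and shapes are illustrative, not the repo's implementation). The point is that autograd hands the callback d(loss)/d(logits) for the loss that is actually optimized, so no extra backward() call inside a collect_grad-style function is needed.

import torch
import torch.nn.functional as F

num_classes = 3                                     # toy number of foreground classes
logits = torch.randn(4, num_classes, requires_grad=True)
targets = torch.zeros(4, num_classes)
targets[0, 1] = 1                                   # two positives, the rest background
targets[2, 0] = 1

pos_grad = torch.zeros(num_classes)
neg_grad = torch.zeros(num_classes)

def collect_grad(grad):
    # grad is d(loss)/d(logits) for the loss actually being optimized,
    # delivered by autograd during the real backward pass
    g = grad.detach().abs()
    pos_grad.add_((g * targets).sum(dim=0))
    neg_grad.add_((g * (1 - targets)).sum(dim=0))

logits.register_hook(collect_grad)                  # hook on the logits tensor

loss = F.binary_cross_entropy_with_logits(logits, targets)   # stand-in for EFL
loss.backward()

print(pos_grad, neg_grad)                           # per-class gradient statistics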
dongnana777 commented 2 years ago

The official deformable-DETR code is only implemented for COCO; it uses 91 classes there, but the largest COCO category id is 90, so the +1 is for the background class. @waveboo

dongnana777 commented 2 years ago

Yes, what I meant is that the existence of aux_loss may make using hooks difficult. If I use a hook in the deformable-DETR code, where should I collect the gradients?

waveboo commented 2 years ago

The COCO dataset has 91 categories; modern detectors map those 91 categories to 80 classes. Please check this yourself. For the gradient hook, you just need to call register_backward_hook on the class_embed layer in deformable-DETR. After collecting the gradient, you need to write your own code to integrate the gradients from the different aux_loss terms inside the collect_grad function.
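A rough sketch of what such a backward hook could look like, using a plain nn.Linear as a stand-in for class_embed (the toy loss, shapes, and target handling here are assumptions, not deformable-DETR code):

import torch
import torch.nn as nn
import torch.nn.functional as F

num_classes = 91                                   # toy value for a COCO-style head
hidden_dim = 256
class_embed = nn.Linear(hidden_dim, num_classes)   # stand-in for deformable-DETR's class_embed

pos_grad = torch.zeros(num_classes)
neg_grad = torch.zeros(num_classes)

# in a real run these one-hot targets would come from the Hungarian matcher
targets_onehot = torch.zeros(2, 300, num_classes)
targets_onehot[0, 0, 5] = 1

def cls_grad_hook(module, grad_input, grad_output):
    # grad_output[0] is d(total loss)/d(class logits) for this application of the head
    g = grad_output[0].detach().abs()
    pos_grad.add_((g * targets_onehot).sum(dim=(0, 1)))
    neg_grad.add_((g * (1 - targets_onehot)).sum(dim=(0, 1)))

class_embed.register_full_backward_hook(cls_grad_hook)

# toy forward/backward just to show the hook firing
feats = torch.randn(2, 300, hidden_dim, requires_grad=True)
logits = class_embed(feats)
loss = F.binary_cross_entropy_with_logits(logits, targets_onehot)
loss.backward()
print(pos_grad.sum(), neg_grad.sum())

With aux_loss enabled, the head is applied once per decoder layer, so a hook like this would be expected to fire for each application and the buffers would accumulate gradients from all auxiliary losses; note that in the official repo class_embed is wrapped in a ModuleList (one entry per prediction level), so hooks may need to be registered on each entry and the matched targets for each layer cached accordingly.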

dongnana777 commented 2 years ago

(two screenshots attached) @waveboo

waveboo commented 2 years ago

@dongnana777 https://github.com/fundamentalvision/Deformable-DETR/issues/152

dongnana777 commented 2 years ago

Wow, thank you, I learned a lot. @waveboo

dongnana777 commented 2 years ago

I have one more question: I tried to use the hook in deformable-DETR, but an error is raised when collecting the gradients, roughly saying that self.cache_target does not exist in EFL. What causes this? @waveboo

dongnana777 commented 2 years ago

Hi, when using the LVIS dataset with def-detr and setting the number of classes to 1203, the program crashes when running on multiple GPUs. Have you run into this? @waveboo

YisWa commented 1 year ago

> Hi, when using the LVIS dataset with def-detr and setting the number of classes to 1203, the program crashes when running on multiple GPUs. Have you run into this? @waveboo

Hi! Did you manage to reproduce this on deformable-DETR?

xiaoche-24 commented 9 months ago

> I have one more question: I tried to use the hook in deformable-DETR, but an error is raised when collecting the gradients, roughly saying that self.cache_target does not exist in EFL. What causes this? @waveboo

Hi, I'm applying EFL to yolov5 and I also ran into problems implementing the gradient collection. Did you end up using hooks to collect the gradients? Could we discuss it? @dongnana777