facebookresearch / detr

End-to-End Object Detection with Transformers

loss_ce and class error not decreasing #507

Open saharshleo opened 2 years ago

saharshleo commented 2 years ago

I am trying to train a similar model for action classification on videos; it has around 200 classes. Here is the code for loss_labels:

import torch
import torch.nn.functional as F

def loss_labels(self, outputs, targets, indices, num_segments, log=True):
        """
        Classification loss (Negative Log Likelihood)
            targets dicts must contain the key "labels" containing a tensor of dim [nb_target_segments]

        Parameters:
            `outputs` (dict) : Output of the model. See forward() for the format.
            `targets` (list) : Ground truth targets of the dataset. 
            `indices` (list) : Bipartite matching of the output and target segments. list (len=batch_size) of tuple of tensors (shape=(2, gt_target_segments)).
            `num_segments` (int) : Average number of target segments across all nodes, for normalization purposes.
            `log` (boolean) : If True, 'class_error' is also calculated and returned.

        Returns: dict {loss : value} where loss can be 'loss_ce' and/or 'class_error'.
        """

        src_logits = outputs['pred_logits'] # (batch_size, num_queries, num_classes)

        # batch_idx - tensor (nb_target_segments) containing batch numbers, and
        # src_idx - tensor (nb_target_segments) containing source indices from the bipartite matcher
        # e.g. [0, 0, 0,   1, 1] and [2, 14, 88,   3, 91]
        idx = self._get_src_permutation_idx(indices) 

        # tensor (nb_target_segments) contains class labels
        # eg. [6, 9, 25,   4, 7] (each index represents a class in its batch)
        # print(targets['video_target'], indices, idx)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets['video_target'], indices)])

        # (batch_size, num_queries) where all elements have a value of self.num_classes ('no-action' has an index of self.num_classes)
        target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device)

        # (batch_size, num_queries) where class labels are assigned based on batch_idx and src_idx. Other elements have a value of self.num_classes
        target_classes[idx] = target_classes_o

        # (batch_size, num_queries, num_classes + 1)
        target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
                                            dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)

        # 1 for positive class, 0 for negative class
        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

        target_classes_onehot = target_classes_onehot[:,:,:-1]
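        # the slice above drops the last ('no-action') channel, so unmatched queries
        # end up with all-zero target rows, i.e. pure negatives for the focal loss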

        loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_segments, alpha=self.focal_alpha, gamma=self.focal_gamma) * src_logits.shape[1]

        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]  

        return losses
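
# For reference, a sketch of the matching helper, assuming it follows upstream
# DETR's SetCriterion._get_src_permutation_idx (models/detr.py):
def _get_src_permutation_idx(self, indices):
    # permute predictions following indices
    batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
    src_idx = torch.cat([src for (src, _) in indices])
    return batch_idx, src_idx
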
def sigmoid_focal_loss(inputs, targets, num_segments, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = 0.25.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """

    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
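    # p_t is the probability the model assigns to the true label of each element;
    # the (1 - p_t) ** gamma factor down-weights easy, well-classified examples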
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean(1).sum() / num_segments
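
A quick toy check of the loss scale (the shapes and values below are purely illustrative, chosen to mimic the setup above):

# sanity check: random logits over 200 classes, 3 positives per sample
B, Q, C = 2, 100, 200
logits = torch.randn(B, Q, C)
onehot = torch.zeros(B, Q, C)
onehot[:, :3, 0] = 1.0  # pretend 3 queries per sample match class 0
toy_loss = sigmoid_focal_loss(logits, onehot, num_segments=6) * Q  # same scaling as loss_ce
print(toy_loss)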

However, losses['class_error'] is stuck at 100.0 and loss_ce oscillates between 400 and 500. Can anyone please help me understand the cause? Feel free to ask if more information is needed.
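
For context, class_error == 100 means that not a single matched query's top-scoring class equals its target label. A minimal check, using only the tensors already defined in loss_labels above:

with torch.no_grad():
    pred = src_logits[idx].argmax(-1)  # (nb_target_segments,)
    print('matched accuracy:', (pred == target_classes_o).float().mean().item())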

violetwzj commented 2 years ago

Maybe you can change batch_size and num_queries; it worked for me.

D10752002 commented 2 years ago

@saharshleo did you find any fix for this issue? I'm training on a custom dataset too and experiencing the same thing: my class error is constant and not decreasing at all.

yarinbar commented 1 year ago

@D10752002 did you solve it? I'm having the same issue here.