FangShancheng / ABINet

Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition

CTCLoss is nan #65

Closed dagongji10 closed 2 years ago

dagongji10 commented 2 years ago

@FangShancheng @AK391 When I train with CRNN+MultiLosses, the loss is normal. But if I change to CRNN+CTCLoss, the loss is normal for the first 4~5 iterations, then becomes inf, and after that nan. Is there any problem in my implementation of CTCLoss?

import torch
from torch import nn
from torch.nn import CTCLoss

class MyCTCLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.ctc = CTCLoss(
                reduction="mean",
                zero_infinity=False)

    @property
    def last_losses(self):
        return self.losses

    def _flatten(self, sources, lengths):
        return torch.cat([t[:l] for t, l in zip(sources, lengths)])

    def _ctc_loss(self, output, gt_labels, gt_lengths, idx=None, record=True):
        loss_name = output.get('name')
        pt_logits, weight = output['logits'], output['loss_weight']
        assert pt_logits.shape[0] % gt_labels.shape[0] == 0

        # CTCLoss expects inputs of shape (T, N, C)
        pt_logits = pt_logits.permute(1, 0, 2)
        log_pt_logits = torch.log_softmax(pt_logits, dim=-1)
        # lengths predicted by the model's _get_length (see the follow-up comment below)
        pt_lengths = output['pt_lengths']
        # flatten the targets, dropping the last position of each label (gt_lengths - 1)
        flat_gt_labels = self._flatten(gt_labels, gt_lengths - 1)
        loss = self.ctc(log_pt_logits, flat_gt_labels, pt_lengths, gt_lengths - 1)

        if record and loss_name is not None: self.losses[f'{loss_name}_loss'] = loss

        return loss

    def forward(self, outputs, *args):
        self.losses = {}
        return self._ctc_loss(outputs, *args, record=False)
dagongji10 commented 2 years ago

The problem is that output['pt_lengths'] comes from _get_length, which is not suitable for CTCLoss. Replace it with

pt_lengths = torch.tensor([pt_logits.size(0)] * pt_logits.size(1))

and the loss becomes normal.
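For anyone hitting the same problem, here is a minimal self-contained sketch (with made-up shapes, not code from this repo) of torch.nn.CTCLoss where every input length equals the full time dimension, which is what the fix above does:

import torch
from torch.nn import CTCLoss

# Hypothetical shapes for illustration only: 26 time steps, batch of 4, 37 classes (0 = blank)
T, N, C = 26, 4, 37
logits = torch.randn(T, N, C)                            # (T, N, C), the layout CTCLoss expects
log_probs = torch.log_softmax(logits, dim=-1)
targets = torch.randint(1, C, (N, 10))                   # dummy labels, class 0 reserved for blank
target_lengths = torch.full((N,), 10, dtype=torch.long)
# Every sample's input length is the full time dimension T; input lengths shorter than the
# target lengths leave no valid CTC alignment, and the loss goes to inf.
input_lengths = torch.full((N,), T, dtype=torch.long)
loss = CTCLoss(reduction="mean")(log_probs, targets, input_lengths, target_lengths)
print(loss)  # finite value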