@FangShancheng @AK391
When I use CRNN+MultiLosses for training, the loss is normal. But if I switch to CRNN+CTCLoss, the loss is normal for the first 4~5 iterations, then becomes inf, and after that nan. Is there any problem in my implementation of CTCLoss?
```python
import torch
from torch import nn
from torch.nn import CTCLoss


class MyCTCLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.losses = {}
        self.ctc = CTCLoss(
            reduction="mean",
            zero_infinity=False)

    @property
    def last_losses(self):
        return self.losses

    def _flatten(self, sources, lengths):
        # Concatenate the valid prefix of each label sequence into one 1-D target tensor.
        return torch.cat([t[:l] for t, l in zip(sources, lengths)])

    def _ctc_loss(self, output, gt_labels, gt_lengths, idx=None, record=True):
        loss_name = output.get('name')
        pt_logits, weight = output['logits'], output['loss_weight']
        assert pt_logits.shape[0] % gt_labels.shape[0] == 0
        # CTCLoss expects (T, N, C) log-probabilities.
        pt_logits = pt_logits.permute(1, 0, 2)
        log_pt_logits = torch.log_softmax(pt_logits, dim=-1)
        pt_lengths = output['pt_lengths']
        # Drop the end token from each target, hence gt_lengths - 1.
        flat_gt_labels = self._flatten(gt_labels, gt_lengths - 1)
        loss = self.ctc(log_pt_logits, flat_gt_labels, pt_lengths, gt_lengths - 1)
        if record and loss_name is not None:
            self.losses[f'{loss_name}_loss'] = loss
        return loss

    def forward(self, outputs, *args):
        self.losses = {}
        return self._ctc_loss(outputs, *args, record=False)
```