baidu-research / warp-ctc

Fast parallel CTC.
Apache License 2.0
4.07k stars 1.04k forks source link

How to add focal loss in CTC? #143

Open Xujianzhong opened 5 years ago

Xujianzhong commented 5 years ago

I want to introduce focal loss in CTC loss in the engineering, and I don't know where to add it. I hope a god can help me. Thanks !

xingchensong commented 5 years ago

me either

Xujianzhong commented 5 years ago

@stephen-song (为了方便交流,请容许我使用中文) 目前ctc里面loss用的交叉熵方式:L=-log(p),我希望修改这个loss的计算方式为:L=(1-p)log(p)这样的形式,请问如何修改代码。 不胜感激!

knowledge3 commented 5 years ago

写一个Lambda函数,先获取ctc loss值,然后对e求-ctc loss次方得到p值,然后就可以得到你想要的新loss了。tf或者keras实现过,pytorch还没尝试

Tangzixia commented 5 years ago

写一个Lambda函数,先获取ctc loss值,然后对e求-ctc loss次方得到p值,然后就可以得到你想要的新loss了。tf或者keras实现过,pytorch还没尝试

大佬你好,请问可以发一下你的代码吗?tensorflow版的,对于α还有一些参数存疑,还望解答

ChChwang commented 3 years ago

写一个Lambda函数,先获取ctc loss值,然后对e求-ctc loss次方得到p值,然后就可以得到你想要的新loss了。tf或者keras实现过,pytorch还没尝试

请问focal loss 加到 ctc loss上有效果吗?

blueardour commented 3 years ago

根据上面讨论想到大概的代码(未验证,最近会跑跑看):

criterion = torch.nn.CTCLoss()
# fm shape: [N, B, C] 
output = F.log_softmax(fm, dim=2)
ctc_loss = criterion(output, target, pred_size, target_length)

p = torch.exp(-ctc_loss)
focal_loss = -(1-p) * torch.log(p)
blueardour commented 3 years ago

[update] no benefit in my training with above code.

huangxin168 commented 2 years ago

这是paddle的:

import paddle from paddle import nn

class CTCLoss(nn.Layer):

  def __init__(self, use_focal_loss=False, **kwargs):
      super(CTCLoss, self).__init__()
      self.loss_func = nn.CTCLoss(blank=0, reduction='none')
      self.use_focal_loss = use_focal_loss

  def forward(self, predicts, batch):
      if isinstance(predicts, (list, tuple)):
          predicts = predicts[-1]
      predicts = predicts.transpose((1, 0, 2))
      N, B, _ = predicts.shape
      preds_lengths = paddle.to_tensor(
          [N] * B, dtype='int64', place=paddle.CPUPlace())
      labels = batch[1].astype("int32")
      label_lengths = batch[2].astype('int64')
      loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
      if self.use_focal_loss:
          weight = paddle.exp(-loss)
          weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
          weight = paddle.square(weight)
          loss = paddle.multiply(loss, weight)
      loss = loss.mean()
      return {'loss': loss}