SeanNaren / warp-ctc

Pytorch Bindings for warp-ctc
Apache License 2.0

Inconsistent result with torch.nn.CTCLoss #189

Closed discort closed 3 years ago

discort commented 3 years ago

Background:

warpctc_pytorch.CTCLoss returns an inf loss, whereas torch.nn.CTCLoss computes a finite float value.

Expected

I'm not sure which loss is computed correctly, but I would expect both to give the same result, right? Or did I miss something and pass incorrect args to the CTC loss?

To reproduce:

import torch
import torch.nn as nn
from warpctc_pytorch import CTCLoss

out = torch.randn(368, 2, 29).requires_grad_()
input_lengths = torch.tensor([119, 179])
labels = torch.randint(1, 27, size=(148,), dtype=torch.long)
label_lengths = torch.tensor([59, 89])

warp_ctc = CTCLoss(size_average=True)
torch_ctc = nn.CTCLoss(reduction='sum')

In [21]: loss1 = warp_ctc(out, labels, input_lengths, label_lengths)
In [22]: loss1
Out[22]: tensor([inf], grad_fn=<_CTCBackward>)

In [24]: loss2 = torch_ctc(out.log_softmax(2), labels, input_lengths, label_lengths)
In [25]: loss2
Out[25]: tensor(771.5262, grad_fn=<SumBackward0>)
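
For reference, a quick diagnostic of the dtypes the snippet above actually passes to the two losses (nothing warp-ctc specific is checked here; it only prints what PyTorch creates by default):

import torch

out = torch.randn(368, 2, 29).requires_grad_()
input_lengths = torch.tensor([119, 179])
labels = torch.randint(1, 27, size=(148,), dtype=torch.long)
label_lengths = torch.tensor([59, 89])

# The activations are float32, but every integer tensor defaults to int64 ("long")
print(out.dtype, input_lengths.dtype, labels.dtype, label_lengths.dtype)
# torch.float32 torch.int64 torch.int64 torch.int64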

Version

Python 3.8.6 (default, Oct 13 2020, 23:19:59)
[Clang 9.0.0 (clang-900.0.39.2)] on darwin
In [17]: torch.__version__
Out[17]: '1.8.0'

UPD: Ran the same code on aarch64 GNU/Linux. On CPU:

double free or corruption (!prev)

On CUDA:

tensor([0.], grad_fn=<_CTCBackward>)
tensor(771.5262, device='cuda:0', grad_fn=<SumBackward0>)
discort commented 3 years ago

But the simple example from the docs works as expected:

out = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).requires_grad_(True)
print(out.shape)
labels = torch.IntTensor([1, 2])
label_lengths = torch.IntTensor([2])
input_lengths = torch.IntTensor([2])

warp_ctc = CTCLoss(size_average=True)
torch_ctc = nn.CTCLoss(reduction='sum')
loss1 = warp_ctc(out, labels, input_lengths, label_lengths)
print(loss1)
loss2 = torch_ctc(out.log_softmax(2), labels, input_lengths, label_lengths)
print(loss2)

Output:

torch.Size([2, 1, 5])
tensor([2.4629], grad_fn=<_CTCBackward>)
tensor(2.4629, grad_fn=<SumBackward0>)

@SeanNaren any ideas about the origin of the double free or corruption (!prev) on CPU, or the zeros on CUDA?

Zhang-O commented 2 years ago

import torch
import torch.nn as nn
from warpctc_pytorch import CTCLoss

out = torch.randn(368, 2, 29)
input_lengths = torch.IntTensor([119, 179])
labels = torch.randint(1, 27, size=(148,)).int()
label_lengths = torch.IntTensor([59, 89])

warp_ctc = CTCLoss()
torch_ctc = nn.CTCLoss(reduction='sum')

loss1 = warp_ctc(out, labels, input_lengths, label_lengths)
print(loss1)

loss2 = torch_ctc(out.log_softmax(2), labels, input_lengths, label_lengths)
print(loss2)

Zhang-O commented 2 years ago

Remember the following: torch.tensor([119, 179]) returns torch.int64, not torch.int32, and torch.randint(1, 27, size=(148,)) also returns torch.int64, not torch.int32.
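
A small check illustrating this (plain PyTorch only; the assumption, based on the working snippet above, is that the warp-ctc binding wants int32 / IntTensor labels and lengths):

import torch

# PyTorch's default integer dtype is int64 ("long"), not int32 ("int"):
print(torch.tensor([119, 179]).dtype)           # torch.int64
print(torch.randint(1, 27, size=(148,)).dtype)  # torch.int64

# Explicit int32 tensors, as used in the working snippet above:
print(torch.IntTensor([119, 179]).dtype)               # torch.int32
print(torch.randint(1, 27, size=(148,)).int().dtype)   # torch.int32
print(torch.tensor([59, 89], dtype=torch.int32).dtype) # torch.int32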