SeanNaren / warp-ctc

Pytorch Bindings for warp-ctc

Doubts about the difference between pytorch's own ctcloss and warp-ctc #191

Open 2017ZYS opened 3 years ago

2017ZYS commented 3 years ago

My test environment is as follows:
########################################################
torch==1.4.0
torchvision==0.5.0
cuda==9.0/10.1
########################################################

My test code is as follows:
########################################################

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from warpctc_pytorch import CTCLoss as warpctc
from torch.nn import CTCLoss as pytorchctc
from torch.autograd import Variable

CHARS = ['京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
         '苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
         '桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁',
         '新',
         '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
         'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
         'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
         'W', 'X', 'Y', 'Z', 'I', 'O', '-'
         ]
alphabet = "".join(CHARS) + 'ç'  
alphabet_dict = {}

for i, char in enumerate(alphabet):
    alphabet_dict[char] = i + 1
length = []
result = []
text = ["湘E269JY","冀PL3N67","川R63728F","津AD6849","苏SDFD45464"]

for s in text:
    length.append(len(s))
    for char in s:
        index = alphabet_dict[char]
        result.append(index)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
targets = torch.IntTensor(result)
targets_lengths = torch.IntTensor(length)
print(targets, targets_lengths)
########
T = 71
N = len(text)
print("N is",N)
C = len(alphabet)
outputs = torch.randn(T,N,C).to(device)
log_probs = outputs.log_softmax(2).detach().requires_grad_().to(device)
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
print(input_lengths,input_lengths.shape)
########
#warp_ctc
criterion_warp = warpctc(blank=0,size_average=False,length_average=False).to(device)
loss_warp1 = criterion_warp(outputs,targets,input_lengths,targets_lengths)
loss_warp2 = criterion_warp(log_probs,targets,input_lengths,targets_lengths)
print(loss_warp1,loss_warp2)
#######
#pytorch ctc
criterion_none = pytorchctc(blank=0,reduction="none")
loss_pytorch = criterion_none(log_probs,targets,input_lengths,targets_lengths)
print(loss_pytorch)

########################################################

The printed result is

tensor([19, 46, 34, 38, 41, 50, 64,  5, 55, 52, 35, 54, 38, 39, 23, 57, 38, 35,
        39, 34, 40, 47,  3, 42, 45, 38, 40, 36, 41, 11, 58, 45, 47, 45, 36, 37,
        36, 38, 36], dtype=torch.int32)  tensor([ 7,  7,  8,  7, 10], dtype=torch.int32)
N is 5
tensor([71, 71, 71, 71, 71]) torch.Size([5])
tensor([828.2159]) tensor([828.2159], grad_fn=<_CTCBackward>)
tensor([278.4382, 289.9979, 276.8926, 278.0664, 272.8851], device='cuda:0',grad_fn=<CudnnCtcLossBackward>)

########################################################

Q: Shouldn't the loss calculated by warp-ctc be the sum of PyTorch's per-sample losses (reduction="none")? I find that 278.4382 + 289.9979 + 276.8926 + 278.0664 + 272.8851 is not equal to 828.2159, but 278.4382 + 276.8926 + 272.8851 is exactly equal to 828.2159. It looks as if warp-ctc only computes a loss for the even-indexed samples and skips the odd-indexed ones. Is this a bug?
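The sums quoted above can be checked quickly; this is just the arithmetic on the printed per-sample values:

vals = [278.4382, 289.9979, 276.8926, 278.0664, 272.8851]
print(sum(vals))                    # about 1396.28, not 828.2159
print(vals[0] + vals[2] + vals[4])  # about 828.2159 (samples 0, 2, 4 only)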

DYJNG commented 3 years ago

Same question! Looking forward to your reply...

Zhang-O commented 2 years ago

Q: Should the loss calculated by warp-ctc be the sum of pytorch's loss (reduction="none")?
A: Yes. You should revise your code and replace

input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)

with

input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.int32)
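If that is indeed the cause (I have not checked the binding's source, so this is an assumption), a plausible reading is that warp-ctc consumes the length tensors as 32-bit integers, so an int64 buffer is misread and every other entry comes out as 0: the effective input lengths become [71, 0, 71, 0, 71], which would match the observation that only samples 0, 2 and 4 contribute to the 828.2159 total. A minimal follow-up check, reusing the variables from the script above:

# assumed fix: pass the input lengths as int32, as suggested above
input_lengths_int = torch.full(size=(N,), fill_value=T, dtype=torch.int32)
loss_warp_fixed = criterion_warp(outputs, targets, input_lengths_int, targets_lengths)
# if the dtype was the only problem, the warp-ctc total should now
# roughly match the sum of PyTorch's per-sample losses
print(loss_warp_fixed, loss_pytorch.sum())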