clovaai / deep-text-recognition-benchmark

Text recognition (optical character recognition) with deep learning methods, ICCV 2019
Apache License 2.0
3.77k stars 1.11k forks source link

how to get word confidence score and character level confidence score using CTC? #389

Open PriyaAhirwar opened 1 year ago

PriyaAhirwar commented 1 year ago

Any ideas about how to get the word confidence score and character level confidence score using CTC? I am able to get it using Attn, but not with CTC. Any suggestions would be appreciated.

Thanks.

ImJaewooChoi commented 1 year ago

Here is some code; just add it to demo.py.

def vis_result(image_path, output_path, text, prob, font_path="fonts/gulim.ttc", font_size=15, text_position=(0, 30), text_color=(0, 0, 255)):
    """Draw the predicted text and its confidence on an image and save it as JPEG.

    Args:
        image_path: Path of the image to annotate.
        output_path: Prefix prepended to the output file name (e.g. "result/").
            It is concatenated as-is, so include the trailing separator.
        text: Predicted label to draw.
        prob: Confidence value; either an object exposing ``.item()`` (such as
            a one-element tensor) or a plain number.
        font_path: TrueType font file used to render the text.
        font_size: Font size passed to the font loader.
        text_position: (x, y) position of the text.
        text_color: Fill colour of the text as an RGB tuple.
    """
    # Define font and size
    font = ImageFont.truetype(font_path, font_size)

    # Accept plain numbers as well as one-element tensors.
    prob_value = prob.item() if hasattr(prob, "item") else prob
    label = text + f"({round(prob_value, 2)})%"

    # splitext keeps dots inside the stem ("a.1.png" -> "a.1"), so differently
    # named inputs no longer collapse onto the same output file.
    stem = os.path.splitext(os.path.basename(image_path))[0]
    target = output_path + stem + '.jpg'
    target_dir = os.path.dirname(target)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    # Convert to RGB so the RGB fill colour is valid and the JPEG writer
    # accepts the image regardless of the source mode; the context manager
    # closes the source file once the pixel data has been copied.
    with Image.open(image_path) as source:
        img = source.convert("RGB")

    # Create drawing object and add text
    draw = ImageDraw.Draw(img)
    draw.text(text_position, label, fill=text_color, font=font)

    # Save image with new text
    img.save(target)

def _ctc_emitted_steps(path_index):
    """Return a boolean mask of the time steps where greedy CTC decoding emits a character.

    A step emits when its label is not the blank and differs from the label of
    the previous step (repeated labels are collapsed).
    NOTE(review): assumes the blank label is index 0 and that converter.decode
    collapses repeats the same way -- confirm against the converter in use.
    """
    emitted = path_index != 0
    emitted[1:] &= path_index[1:] != path_index[:-1]
    return emitted


# Run the model over the demo set and, for every image, log the predicted
# label, the per-character probabilities and the word confidence score (%).
model.eval()
for image_tensors, image_path_list in demo_loader:
    batch_size = image_tensors.size(0)
    # Keep the whole forward pass and the probability computation under
    # no_grad so no autograd graph is built during inference.
    with torch.no_grad():
        image = image_tensors.cuda()
        # For max length prediction
        length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size)
        text_for_pred = torch.cuda.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text_for_pred).log_softmax(2)

            # Select max probability (greedy decoding) then decode index to character.
            # The converter is given the flattened indices plus one length per sample.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index.contiguous().view(-1).data, preds_size.data)

        else:
            preds = model(image, text_for_pred, is_train=False)
            # Select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)

        # Softmax is shift-invariant, so applying it to the log-softmax output
        # of the CTC branch yields the same probabilities as applying it to logits.
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)

    dashed_line = '-' * 80
    head = f'{"image_path":25s}\t\t{"predicted_labels":25s}\t\t{"each score":25s}\t\tconfidence score'

    # The context manager closes the log even if a sample below raises.
    with open('./log_result.txt', 'a') as log:
        print(f'{dashed_line}\n{head}\n{dashed_line}')
        log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

        for img_name, pred, pred_max_prob, pred_index in zip(image_path_list, preds_str, preds_max_prob, preds_index):
            if 'Attn' in opt.Prediction:
                pred_EOS = pred.find('[s]')
                # find() returns -1 when there is no "end of sentence" token;
                # slicing with -1 would silently drop the last character.
                if pred_EOS != -1:
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]
                char_prob = pred_max_prob
            elif 'CTC' in opt.Prediction:
                # Character-level scores: probability at each emitting time step.
                char_prob = pred_max_prob[_ctc_emitted_steps(pred_index)]
            else:
                char_prob = pred_max_prob
            fin_each_prob = ','.join(map(str, char_prob.cpu().numpy()))

            # Confidence score (%) = product of the per-step max probabilities.
            # For CTC this runs over every time step (blanks included), i.e.
            # the probability of the greedy path. An empty prediction scores 0
            # instead of raising on an empty tensor.
            if pred_max_prob.numel() > 0:
                confidence_score = pred_max_prob.prod() * 100
            else:
                confidence_score = pred_max_prob.new_zeros(())

            vis_result(img_name, "result/", pred, confidence_score, text_color=(0, 0, 255))
            print(f'{img_name:25s}\t\t{pred:25s}\t\t{fin_each_prob}\t\t{confidence_score:0.4f}')
            log.write(f'{img_name:25s}\t\t{pred:25s}\t\t{fin_each_prob}\t\t{confidence_score:0.4f}\n')