Open — PriyaAhirwar opened this issue 1 year ago
Here is some code; just add it to demo.py.
def vis_result(image_path, output_path, text, prob, font_path="fonts/gulim.ttc", font_size=15, text_position=(0, 30), text_color=(0, 0, 255)):
    """Draw a prediction and its confidence onto an image and save it as JPEG.

    Args:
        image_path: Path of the source image.
        output_path: Directory the annotated image is written to. A trailing
            separator is optional; the directory must already exist.
        text: Predicted string to draw.
        prob: Confidence, either a one-element tensor or a plain number. It is
            rounded to two decimals and shown with a percent sign, so pass a
            value already scaled to 0-100 (the demo loop passes
            ``product * 100``).
        font_path: TrueType/TTC font used for the label. The default is a
            Korean-capable font.
        font_size: Font size passed to ``ImageFont.truetype``.
        text_position: (x, y) of the label's top-left corner.
        text_color: RGB fill colour of the label.

    The output file is ``<output_path>/<image stem>.jpg``, where the stem is the
    source file name without its final extension.
    """
    font = ImageFont.truetype(font_path, font_size)
    # float() accepts both one-element tensors and plain numbers.
    label = f"{text}({round(float(prob), 2)}%)"

    # splitext strips only the last extension, so "a.b.png" -> "a.b".
    stem = os.path.splitext(os.path.basename(image_path))[0]

    # JPEG cannot store alpha/palette modes, and an RGB fill tuple needs an
    # RGB canvas, so draw on an RGB copy. The context manager closes the
    # source file handle.
    with Image.open(image_path) as src:
        img = src.convert("RGB")

    draw = ImageDraw.Draw(img)
    draw.text(text_position, label, fill=text_color, font=font)
    img.save(os.path.join(output_path, stem + ".jpg"))
# Run the recognizer over every batch of demo_loader, then print, log, and
# render (via vis_result) the prediction for each image.
#
# Confidence scores:
#   * per character: the max softmax probability of the decoding step/frame
#     that produced each character of the decoded string;
#   * word: the product of the per-character probabilities, times 100.
# For Attn, steps from the "[s]" end-of-sentence token onward are dropped.
# For CTC, the greedy collapse is replayed on the frame argmax indices (blank
# and repeated frames dropped) so exactly one probability remains per character.

RESULT_DIR = "result/"
LOG_PATH = "./log_result.txt"


def _ctc_char_probs(frame_index, frame_prob, blank=0):
    """Return the probabilities of the frames that greedy CTC decoding emits.

    A frame is kept when its argmax index is not the blank label and differs
    from the previous frame's index (standard greedy CTC collapse).
    NOTE(review): assumes the blank label is index 0 and that converter.decode
    applies this same rule with one character per kept index — confirm against
    the converter in use.

    Args:
        frame_index: 1-D tensor of argmax label indices, one per frame.
        frame_prob: 1-D tensor of the matching max probabilities.

    Returns:
        1-D tensor with one probability per emitted character (may be empty).
    """
    keep = frame_index != blank
    keep[1:] &= frame_index[1:] != frame_index[:-1]
    return frame_prob[keep]


os.makedirs(RESULT_DIR, exist_ok=True)  # img.save() does not create directories
model.eval()
for image_tensors, image_path_list in demo_loader:
    batch_size = image_tensors.size(0)
    with torch.no_grad():
        image = image_tensors.cuda()
        # Placeholder inputs for max-length prediction.
        length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size)
        text_for_pred = torch.cuda.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text_for_pred).log_softmax(2)
            # Greedy decoding: argmax label per frame, then decode to characters.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            _, frame_index = preds.max(2)  # one row of frame indices per image
            preds_str = converter.decode(frame_index.reshape(-1).data, preds_size.data)
        else:
            preds = model(image, text_for_pred, is_train=False)
            # Greedy decoding: argmax label per step, then decode to characters.
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)

        # softmax ignores the per-row shift that log_softmax applies, so this
        # yields probabilities for both the CTC and the Attn head.
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)

    dashed_line = '-' * 80
    head = f'{"image_path":25s}\t\t{"predicted_labels":25s}\t\t{"each score":25s}\t\tconfidence score'
    print(f'{dashed_line}\n{head}\n{dashed_line}')
    # Predictions may contain non-ASCII (e.g. Korean) text, so force UTF-8.
    with open(LOG_PATH, 'a', encoding='utf-8') as log:
        log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')
        for i, (img_name, pred, pred_max_prob) in enumerate(zip(image_path_list, preds_str, preds_max_prob)):
            if 'CTC' in opt.Prediction:
                pred_max_prob = _ctc_char_probs(frame_index[i], pred_max_prob)
            elif 'Attn' in opt.Prediction:
                pred_eos = pred.find('[s]')
                # find() returns -1 when there is no EOS token; slicing with -1
                # would silently drop the last character.
                if pred_eos != -1:
                    pred = pred[:pred_eos]
                    pred_max_prob = pred_max_prob[:pred_eos]

            each_prob = list(pred_max_prob.detach().cpu().numpy())
            fin_each_prob = ','.join(map(str, each_prob))
            # prod() of an empty tensor is 1, so an empty prediction no longer
            # raises the IndexError that cumprod(...)[-1] did.
            confidence_score = pred_max_prob.prod() * 100
            vis_result(img_name, RESULT_DIR, pred, confidence_score, text_color=(0, 0, 255))

            line = f'{img_name:25s}\t\t{pred:25s}\t\t{fin_each_prob}\t\t{float(confidence_score):0.4f}'
            print(line)
            log.write(line + '\n')
Any ideas on how to get word-level and character-level confidence scores with CTC? I can get them with Attn, but not with CTC. Any suggestions would be appreciated.
Thanks.