JaidedAI / EasyOCR

Ready-to-use OCR with 80+ supported languages and all popular writing scripts, including Latin, Chinese, Arabic, Devanagari, Cyrillic, etc.
https://www.jaided.ai
Apache License 2.0
24.33k stars, 3.15k forks

Transfer-learning to improve accuracy for a specific font and background #317

Closed kurt-stolle closed 3 years ago

kurt-stolle commented 3 years ago

In my specific use case, I only need to recognize text that is always set in the same font, on a background that is always darker than the foreground. Additionally, the character set is smaller than the one that the default latin.pth checkpoint is trained on. The default model does not meet my accuracy requirements, even after tuning each parameter. I would therefore like to fine-tune the model on generated image-text pairs in the target style of my application.

My current solution loads the default latin.pth model, but replaces its prediction layer with my own, sized for my reduced character set. Transfer learning is then performed by training only the weights of this final layer. The training process is set up as follows:

import tempfile

from os import path
from typing import List

from easyocr.model.model import Model
from easyocr.utils import CTCLabelConverter
from easyocr.recognition import AlignCollate
import easyocr

import torch
import torch.nn.functional
import torch.utils.data

import numpy as np

from .. import textgen, translations

# Omitted: methods that wrap all code below

# Lexicon, a dictionary of words that appear in this specific use-case
lexicon = translations.Lexicon(dict_path)
lexicon.augment_capitalization(len(lexicon))
lexicon.augment_jibberish(round(len(lexicon)/5))
lexicon.augment_numbers(round(len(lexicon)/10))

# Initialize the CTC label converter (see CRNN paper)
character = lexicon.characters()
character.add(' ')

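# Sort the characters so the index-to-character mapping is identical on every run
# (assumes lexicon.characters() returns a set, whose iteration order is not stable)
converter = CTCLabelConverter(sorted(character))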
num_class = len(converter.character)

# Initialize the feature mapping and sequence labeling model
model = Model(input_channel, output_channel, hidden_size, num_class=num_class)
model = torch.nn.DataParallel(model).to(dev)

for name, param in model.named_parameters():
    if 'localization_fc2' in name:
        print(f'Skip {name} as it is already initialized')
        continue
    try:
        if 'bias' in name:
            torch.nn.init.constant_(param, 0.)
        elif 'weight' in name:
            torch.nn.init.kaiming_normal_(param)
    except Exception:  # for batchnorm.
        if 'weight' in name:
            param.data.fill_(1.)
        continue

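# Load weights from EasyOCR default checkpoint for latin characters
checkpoint = torch.load("models/latin.pth", map_location=dev)
# Replace the prediction layer with a freshly initialized one sized for the reduced character set
checkpoint["module.Prediction.weight"] = torch.randn((num_class, hidden_size)) * 0.01
checkpoint["module.Prediction.bias"] = torch.zeros(num_class)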

model.load_state_dict(checkpoint)
model.train()

# Disable training on all layers except the final prediction layer
for param in model.parameters():
    param.requires_grad = False

for param in model.module.Prediction.parameters():
    param.requires_grad = True

# CTC Loss criterion
criterion = torch.nn.CTCLoss(zero_infinity=False).to(dev)

# Optimizer (Adagrad is used here; note that the CRNN paper itself trains with ADADELTA)
filtered_parameters = []
params_num = 0
for p in filter(lambda p: p.requires_grad, model.parameters()):
    filtered_parameters.append(p)
    params_num += np.prod(p.size())

print('Trainable params num : ', params_num)

optimizer = torch.optim.Adagrad(filtered_parameters)

# Wrap AlignCollate in our own collate function
def collate(batch):
    ratios = []
    imgs = []
    lbls = []
    for img, lbl in batch:
        lbls.append(lbl)
        ratios.append(float(img.size[0])/float(img.size[1]))
        imgs.append(img)

    imgs = AlignCollate(imgH=64, imgW=int(max(ratios) * 64), keep_ratio_with_pad=True)(imgs)

    return imgs, lbls

# TRDG library requires a directory to look for background images
with tempfile.TemporaryDirectory() as bg_dir:
    # Dataset that yields random (image, text) pairs, sentences of at most 3 random words, using the TRDG library
    ds = textgen.TextDataset(lexicon, bg_dir)
    train_dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, pin_memory=True, drop_last=True, collate_fn=collate)

    # Required for CTCLoss
    torch.backends.cudnn.deterministic = True

    # Training loop
    for (i, (img, lbl)) in enumerate(train_dl):
        img = img.to(dev)

        # Encode the text label
        lbl_encoded, length = converter.encode(lbl)

        # Run the model
        model.zero_grad()
        preds = model(img, None)
        preds_size = torch.IntTensor([preds.size(1)] * img.size(0))
        preds_log_softmax = preds.log_softmax(2)

        # Calculate loss
        cost = criterion(preds_log_softmax.permute(1, 0, 2), lbl_encoded, preds_size, length)

        print(f"Iteration\t{i}:\tcost {cost.item()}")

        # Optimizer step
        cost.backward()
        torch.nn.utils.clip_grad_norm_(filtered_parameters, 5)
        optimizer.step()

        # Show the result every 5 steps
        if i % 5 == 0:
            _, preds_index = preds_log_softmax.detach().max(2)
            preds_index = preds_index.view(-1)
            preds_str = converter.decode_greedy(preds_index.data, preds_size.data)

            for idx, (true_lbl, pred_lbl) in enumerate(zip(lbl, preds_str)):
                print(f"\t- true {idx}\t: {true_lbl}")
                print(f"\t- pred {idx}\t: {pred_lbl}")
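
For readers unfamiliar with the decoding step at the end of the loop above: greedy CTC decoding takes the most likely class at each time step, collapses consecutive repeats, and then drops the blank symbol. The following is a simplified pure-Python sketch of that idea, not EasyOCR's actual `decode_greedy`; the function name `ctc_greedy_decode` is hypothetical, and it assumes index 0 is reserved for the blank.

```python
from typing import List, Sequence


def ctc_greedy_decode(indices: Sequence[int], characters: Sequence[str], blank: int = 0) -> str:
    """Collapse consecutive repeated indices, then drop blanks.

    `indices` holds the arg-max class per time step. `characters` lists the
    real characters only; index 0 is assumed to be the blank, so class i
    maps to characters[i - 1].
    """
    decoded: List[str] = []
    previous = blank
    for idx in indices:
        # Emit a character only when it is not blank and differs from the previous step
        if idx != blank and idx != previous:
            decoded.append(characters[idx - 1])
        previous = idx
    return "".join(decoded)


# A blank between two identical indices keeps both characters: "aabc"
example = ctc_greedy_decode([1, 1, 0, 1, 2, 2, 0, 3], ["a", "b", "c"])
```

The blank is what allows doubled letters to survive the collapsing step, which is why the prediction layer has one more output than the character set has entries.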

While this already significantly improves accuracy, I would like to go further and also train the remaining layers. I notice though that when I try to train all layers simultaneously, the model quickly diverges. It is not clear to me whether this is due to a mistake in the training script or something else that I am not accounting for.

How can I (re-)train the model either from scratch or using latin.pth as a starting point?
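
The thread does not establish why training all layers diverges. As a starting point for experiments, the sketch below combines three adjustments that are commonly tried when fine-tuning a pretrained network: a much smaller learning rate for the pretrained layers than for the new head, BatchNorm layers kept in eval mode so their running statistics stay fixed, and `zero_infinity=True` in `torch.nn.CTCLoss` so that samples whose input is too short for their target do not produce infinite losses. None of this has been verified against the setup in this issue. `TinyRecognizer`, `freeze_batchnorm`, `build_optimizer`, and the learning rates are hypothetical; the toy model only stands in for the real recognizer.

```python
import torch
from torch import nn


class TinyRecognizer(nn.Module):
    """Toy stand-in for the recognizer: a feature extractor plus a Prediction head."""

    def __init__(self, num_class: int, hidden_size: int = 16):
        super().__init__()
        self.FeatureExtraction = nn.Sequential(
            nn.Conv1d(8, hidden_size, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
        )
        self.Prediction = nn.Linear(hidden_size, num_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.FeatureExtraction(x)               # (batch, hidden, time)
        return self.Prediction(feats.permute(0, 2, 1))  # (batch, time, num_class)


def freeze_batchnorm(model: nn.Module) -> None:
    """Put BatchNorm layers in eval mode so their running statistics stay fixed."""
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
            module.eval()


def build_optimizer(model: nn.Module, head_lr: float = 1e-2, backbone_lr: float = 1e-4):
    """Adagrad with a smaller learning rate for pretrained layers than for the new head."""
    head = list(model.Prediction.parameters())
    head_ids = {id(p) for p in head}
    backbone = [p for p in model.parameters() if id(p) not in head_ids]
    return torch.optim.Adagrad([
        {"params": backbone, "lr": backbone_lr},
        {"params": head, "lr": head_lr},
    ])


torch.manual_seed(0)
model = TinyRecognizer(num_class=5)
model.train()
freeze_batchnorm(model)  # must be repeated after every call to model.train()
optimizer = build_optimizer(model)
criterion = nn.CTCLoss(blank=0, zero_infinity=True)

x = torch.randn(2, 8, 12)                # (batch, channels, time)
targets = torch.tensor([1, 2, 3, 4, 2])  # labels of both samples, concatenated
target_lengths = torch.tensor([3, 2])
input_lengths = torch.full((2,), 12, dtype=torch.long)

# One training step
log_probs = model(x).log_softmax(2).permute(1, 0, 2)  # (time, batch, num_class)
loss = criterion(log_probs, targets, input_lengths, target_lengths)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
optimizer.step()
```

In the script above, the equivalent changes would be to pass two parameter groups to the optimizer instead of `filtered_parameters` and to call a helper like `freeze_batchnorm` right after `model.train()`.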

Ameya-Manas commented 3 years ago

Hello @kurt-stolle, can you share the file and folder structure of textgen and translations? Or, if they are publicly available, where can I find them?

kurt-stolle commented 3 years ago

Hi @Ameya-Manas, textgen and translations are packages that help generate label-image pairs via torch.utils.data.IterableDataset in the target style of the OCR-application with some augmentations applied (distortions, warps, etc.). Sadly, company policy does not allow me to share these files here.
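
Since the original packages cannot be shared, here is a rough, purely hypothetical sketch of what the label half of such a generator might look like: a small lexicon that is augmented with capitalization variants and gibberish, from which labels of at most three words are sampled. The function names and behavior are guesses for illustration only and are not the author's code. Image rendering (for example with TRDG) and wrapping the generator in a `torch.utils.data.IterableDataset` are left out.

```python
import random
import string
from typing import Iterator, List


def augment_capitalization(words: List[str]) -> List[str]:
    """Return the words plus upper-case and capitalized variants, without duplicates."""
    variants = set(words)
    for word in words:
        variants.add(word.upper())
        variants.add(word.capitalize())
    return sorted(variants)


def random_gibberish(rng: random.Random, alphabet: str, min_len: int = 3, max_len: int = 8) -> str:
    """Build a random string so the model does not only see dictionary words."""
    length = rng.randint(min_len, max_len)
    return "".join(rng.choice(alphabet) for _ in range(length))


def sample_labels(words: List[str], rng: random.Random, max_words: int = 3) -> Iterator[str]:
    """Yield an endless stream of labels made of 1 to max_words random words."""
    while True:
        count = rng.randint(1, max_words)
        yield " ".join(rng.choice(words) for _ in range(count))


rng = random.Random(0)  # fixed seed so runs are repeatable
words = augment_capitalization(["speed", "limit"])
words.append(random_gibberish(rng, string.ascii_lowercase))
labels = sample_labels(words, rng)
batch = [next(labels) for _ in range(4)]
```

Each label in `batch` would then be rendered to an image in the target font and color scheme and yielded together with its text.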

Ameya-Manas commented 3 years ago

@kurt-stolle Ok no problem. Thank you anyway. :) Your code has still given me some ideas as to how to proceed. :)

piotrostr commented 3 years ago

"Training pipeline for recognition part is a modified version from this repository. [https://github.com/clovaai/deep-text-recognition-benchmark]" citing readme from https://github.com/JaidedAI/EasyOCR/tree/a5d3053df952ccc411863e1b0d690f3678c9da03

kurt-stolle commented 3 years ago

@piotr-ost Indeed, the code in this issue is too. What is your point?

piotrostr commented 3 years ago

Hey @kurt-stolle. That repository has a guide on training with your own dataset, as well as a list of failure cases. In the paper, they also outline some possible solutions to overcome failure cases such as low-resolution images. I faced a similar problem to yours and found it helpful, so I thought it might be worth sharing :)

LanzaMercado commented 3 years ago

Hi everyone. I would like to get my hands dirty and try to re-train (more likely, apply transfer learning to) EasyOCR to improve its performance on my dataset. Where exactly is the guide you mention, @piotr-ost? In my case, I have low-resolution images. EasyOCR performs really well in general, but I want to try to improve it for each extracted ROI. Any help you can dumb down for me would be really appreciated. My first step is to try to understand the code given by @kurt-stolle, since I have never used PyTorch. Thanks in advance!

piotrostr commented 3 years ago

Hi there @LanzaMercado, all the training steps are outlined in https://github.com/clovaai/deep-text-recognition-benchmark. If you are struggling with low-resolution images, super-resolution might be a way to improve accuracy. Personally, I found that using the RDN_x4 model from the repository https://github.com/yjn870/RDN-pytorch made the predictions a bit better. Good luck!

kurt-stolle commented 3 years ago

@LanzaMercado Sadly, this issue was never resolved. I ended up writing my own OCR library with off-the-shelf networks, using a three-stage approach similar to EasyOCR's.