JaidedAI / EasyOCR

Ready-to-use OCR with 80+ supported languages and all popular writing scripts, including Latin, Chinese, Arabic, Devanagari, Cyrillic, etc.
https://www.jaided.ai
Apache License 2.0

Use model trained with deep-text-recognition-benchmark #991

Open SofieGeens opened 1 year ago

SofieGeens commented 1 year ago

I trained a custom model with deep-text-recognition-benchmark. I moved the .pth file and the model.py file from that repository to the correct locations on my machine and renamed the files so they all share the same name. I think there is still a problem in model.py. I get the following error when I try to run it:

```
Traceback (most recent call last):
  File "C:\Users\sofie\OneDrive\Documenten\unif\2022-2023\masterproef\stageOSG\testOCR.py", line 34, in <module>
    reader = easyocr.Reader(['en'], recog_network='numbers_model')
  File "C:\Users\sofie\AppData\Local\Programs\Python\Python310\lib\site-packages\easyocr\easyocr.py", line 227, in __init__
    self.recognizer, self.converter = get_recognizer(recog_network, network_params,\
  File "C:\Users\sofie\AppData\Local\Programs\Python\Python310\lib\site-packages\easyocr\recognition.py", line 166, in get_recognizer
    model = model_pkg.Model(num_class=num_class, **network_params)
TypeError: Model.__init__() got an unexpected keyword argument 'num_class'
```

I'm fairly sure the problem is in the .py file, but I don't know how to fix it. My model.py looks like this:

```python
import torch.nn as nn
from modules.transformation import TPS_SpatialTransformerNetwork
from modules.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor
from modules.sequence_modeling import BidirectionalLSTM
from modules.prediction import Attention


class Model(nn.Module):

    def __init__(self, opt):
        super(Model, self).__init__()
        self.opt = opt
        self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
                       'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

        """ Transformation """
        if opt.Transformation == 'TPS':
            self.Transformation = TPS_SpatialTransformerNetwork(
                F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
        else:
            print('No Transformation module specified')

        """ FeatureExtraction """
        if opt.FeatureExtraction == 'VGG':
            self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'RCNN':
            self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'ResNet':
            self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
        else:
            raise Exception('No FeatureExtraction module specified')
        self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

        """ Sequence modeling"""
        if opt.SequenceModeling == 'BiLSTM':
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
                BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
            self.SequenceModeling_output = opt.hidden_size
        else:
            print('No SequenceModeling module specified')
            self.SequenceModeling_output = self.FeatureExtraction_output

        """ Prediction """
        if opt.Prediction == 'CTC':
            self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
        elif opt.Prediction == 'Attn':
            self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
        else:
            raise Exception('Prediction is neither CTC or Attn')

    def forward(self, input, text, is_train=True):
        """ Transformation stage """
        if not self.stages['Trans'] == "None":
            input = self.Transformation(input)

        """ Feature extraction stage """
        visual_feature = self.FeatureExtraction(input)
        visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))  # [b, c, h, w] -> [b, w, c, h]
        visual_feature = visual_feature.squeeze(3)

        """ Sequence modeling stage """
        if self.stages['Seq'] == 'BiLSTM':
            contextual_feature = self.SequenceModeling(visual_feature)
        else:
            contextual_feature = visual_feature  # for convenience. this is NOT contextually modeled by BiLSTM

        """ Prediction stage """
        if self.stages['Pred'] == 'CTC':
            prediction = self.Prediction(contextual_feature.contiguous())
        else:
            prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=self.opt.batch_max_length)

        return prediction
```

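The post mentions renaming the files so they all share one name. Per EasyOCR's documentation for custom recognizers, `easyocr.Reader` looks for `<name>.pth` in the model storage directory (default `~/.EasyOCR/model`) and for `<name>.py` and `<name>.yaml` in the user network directory (default `~/.EasyOCR/user_network`). The helper below is a minimal sketch under that assumption that lists which of the three files are missing; `missing_custom_model_files` is a hypothetical name, not part of EasyOCR.

```python
from pathlib import Path


def missing_custom_model_files(name, model_dir, user_network_dir):
    """Hypothetical helper: return the expected custom-model files that do not exist.

    Assumes EasyOCR's documented layout: <name>.pth in the model storage
    directory, <name>.py and <name>.yaml in the user network directory.
    """
    expected = [
        Path(model_dir) / f"{name}.pth",
        Path(user_network_dir) / f"{name}.py",
        Path(user_network_dir) / f"{name}.yaml",
    ]
    return [path for path in expected if not path.is_file()]


if __name__ == "__main__":
    # Default locations; pass the same directories you give to easyocr.Reader
    # (model_storage_directory / user_network_directory) if you changed them.
    easyocr_home = Path.home() / ".EasyOCR"
    for path in missing_custom_model_files("numbers_model",
                                           easyocr_home / "model",
                                           easyocr_home / "user_network"):
        print("missing:", path)
```

In this issue the traceback already reaches `model_pkg.Model(...)`, which suggests the .py and .yaml files were found, so the layout is probably not the cause here; the check is mainly a quick first step when a custom model does not load at all.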
ShehneelAhmedKhan commented 8 months ago

Same issue here

txzskywalker commented 6 months ago

I ran into this problem too: model = model_pkg.Model(num_class=num_class, **network_params)

DivyangiPanchal commented 6 months ago

The `__init__(self, opt)` method expects a single `opt` object, whereas EasyOCR calls `model = model_pkg.Model(num_class=num_class, **network_params)`, which passes `num_class` and every entry of `network_params` as keyword arguments.
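This mismatch can be reproduced without torch or EasyOCR. The sketch below uses a hypothetical stand-in class with the same `__init__(self, opt)` signature and makes the same keyword-argument call that the traceback shows in `recognition.py`. Python rejects the unknown `num_class` keyword before it reports the missing `opt` argument, which is why the traceback mentions `num_class` rather than `opt`. The parameter values are illustrative only.

```python
class Model:
    """Stand-in with the same constructor signature as the
    deep-text-recognition-benchmark Model (torch layers omitted)."""

    def __init__(self, opt):
        self.opt = opt


# Illustrative values; in EasyOCR, network_params is read from the model's .yaml file.
network_params = {"input_channel": 1, "output_channel": 512, "hidden_size": 256}
num_class = 11  # hypothetical: 10 digits plus the CTC blank

error_message = None
try:
    # Same call pattern as easyocr/recognition.py in the traceback.
    Model(num_class=num_class, **network_params)
except TypeError as err:
    error_message = str(err)

print(error_message)
```

On Python 3.10 the printed message should match the last line of the traceback; older versions omit the `Model.` prefix.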

  1. I think one needs to redefine `__init__(self, opt)` as shown in the custom_example.py file, i.e. `def __init__(self, input_channel, output_channel, hidden_size, num_class):`
  2. One also needs to add a `network_params` section to the config.yaml file, just as in custom_example.yaml.
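An alternative to rewriting `Model` after custom_example.py is to keep the trained model.py and let its constructor accept EasyOCR's keyword arguments, rebuilding the `opt` object from them. The sketch below assumes that approach; `make_opt` is a hypothetical helper, and every default in it is an assumption that must be replaced with the values actually passed to deep-text-recognition-benchmark's training script (`--Transformation`, `--FeatureExtraction`, `--SequenceModeling`, `--Prediction`, `--imgH`, `--imgW`, and so on), or the saved weights will not fit the rebuilt network.

```python
from argparse import Namespace


def make_opt(num_class, input_channel=1, output_channel=512, hidden_size=256,
             Transformation="None", FeatureExtraction="VGG",
             SequenceModeling="BiLSTM", Prediction="CTC",
             num_fiducial=20, imgH=32, imgW=100, batch_max_length=25):
    """Hypothetical helper: turn EasyOCR-style keyword arguments into the
    opt namespace that the deep-text-recognition-benchmark Model reads.

    All defaults are assumptions; override them (e.g. via network_params in
    the .yaml file) with the settings used when the .pth was trained.
    """
    return Namespace(
        num_class=num_class,
        input_channel=input_channel,
        output_channel=output_channel,
        hidden_size=hidden_size,
        Transformation=Transformation,
        FeatureExtraction=FeatureExtraction,
        SequenceModeling=SequenceModeling,
        Prediction=Prediction,
        num_fiducial=num_fiducial,
        imgH=imgH,
        imgW=imgW,
        batch_max_length=batch_max_length,
    )


# Mimics EasyOCR's call Model(num_class=num_class, **network_params);
# these values are illustrative only.
network_params = {"input_channel": 1, "output_channel": 256, "hidden_size": 256,
                  "FeatureExtraction": "VGG", "Prediction": "CTC"}
opt = make_opt(num_class=11, **network_params)
print(opt)
```

In model.py, the constructor would then become `def __init__(self, num_class, **network_params):`, call `super(Model, self).__init__()`, and set `opt = make_opt(num_class, **network_params)` before the unchanged original body. Because the network is rebuilt from the same options used in training, the .pth weights should still line up. One caveat, as far as we can tell: EasyOCR decodes recognizer output with a CTC label converter, so a model trained with `--Prediction Attn` is unlikely to give correct text this way.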