ria-com / nomeroff-net

Nomeroff Net. Automatic numberplate recognition system.
GNU General Public License v3.0

Unable to export OCR Keras model #126

Closed azhuchkov closed 1 year ago

azhuchkov commented 3 years ago

I'm trying to load the OCR model using the OpenCV DNN module, which requires converting it to the TensorFlow frozen graph format.

Apparently NomeroffNet previously had a special function for this, which is mentioned in one of the provided notebooks: from NomeroffNet.Base import convert_keras_to_freeze_pb (it is absent right now).

What I have right now (based on https://github.com/opencv/opencv/issues/16582):

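    import tensorflow as tf
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

    model = textDetector.MODEL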
    model.summary()

    # Convert Keras model to ConcreteFunction
    full_model = tf.function(lambda x: model(x))
    full_model = full_model.get_concrete_function(
        tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

    # Get frozen ConcreteFunction
    frozen_func = convert_variables_to_constants_v2(full_model)
    # frozen_func.graph.as_graph_def()

    # Save frozen graph from frozen ConcreteFunction to hard drive
    # tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
    #                   logdir=".",
    #                   name=OUTPUT_FILE,
    #                   as_text=False)

    graph_def = frozen_func.graph.as_graph_def()

    # Export frozen graph
    with tf.io.gfile.GFile(OUTPUT_FILE, 'wb') as f:
        f.write(graph_def.SerializeToString())

It runs without errors or warnings, but I'm then unable to load the exported file:

import cv2 as cv
import numpy as np

net = cv.dnn.readNetFromTensorflow('ocr.pb')

produces:

Traceback (most recent call last):
  File "/Users/user/Projects/ocr/test_ocr.py", line 4, in <module>
    net = cv.dnn.readNetFromTensorflow('ocr.pb')
cv2.error: OpenCV(4.5.1) /private/var/folders/nz/vv4_9tw56nv9k3tkvyszvwg80000gn/T/pip-req-build-oe0iat4a/opencv/modules/dnn/src/graph_simplifier.cpp:76: error: (-212:Parsing error) Input node with name model/dense2/Tensordot/Prod not found in function 'getInputNodeId'

The error also changes after each export.

Any thoughts?

dimabendera commented 3 years ago

The freeze graph APIs are not available in TensorFlow 2.x, but the .save method already writes a .pb that is ready for inference. I managed to convert it this way:

import sys 
import os
import cv2
import numpy as np

NOMEROFF_NET_DIR = "../../"

# setup result model dirs
SAVED_MODEL_DIR  = os.path.join(NOMEROFF_NET_DIR, "./models/saved_model_ocr_eu")

sys.path.append(NOMEROFF_NET_DIR)

os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
import tensorflow as tf
from NomeroffNet.TextDetector import TextDetector

textDetector = TextDetector.get_static_module("eu")()
textDetector.load("latest")

textDetector.MODEL.summary()

# save pb model
tf.saved_model.save(textDetector.MODEL, SAVED_MODEL_DIR)

# load dnn
net = cv2.dnn.readNetFromTensorflow(SAVED_MODEL_DIR)

# get image and normalize
img = cv2.imread('../../examples/crop_np_images/RP70012.png')
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
img = cv2.resize(img, (128, 64))  # keep the result; cv2.resize does not modify img in place
img = img.astype(np.float32)
img -= np.amin(img)
img /= np.amax(img)
img = [[[h] for h in w] for w in img.T]

net.setInput(np.array(img))

# Runs a forward pass to compute the net output
networkOutput = net.forward()
print(networkOutput)
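
The normalization step in the snippet above can be sketched in isolation with plain NumPy. This is only an illustration of that one step: the function name normalize_plate and the eps guard are assumptions added here, not part of NomeroffNet, and the output layout (width, height, 1) simply mirrors the nested list comprehension above rather than anything OpenCV is known to require.

```python
import numpy as np


def normalize_plate(gray, eps=1e-8):
    """Min-max scale a 2-D grayscale image and return it as (width, height, 1)."""
    # Work in float32, as in the snippet above.
    img = np.asarray(gray, dtype=np.float32)
    # Shift so the minimum is 0, then scale so the maximum is 1.
    # eps is an added guard (not in the original) against a constant image.
    img = img - img.min()
    img = img / max(float(img.max()), eps)
    # Transpose to (width, height) and append a channel axis; this is the
    # vectorized equivalent of [[[h] for h in w] for w in img.T].
    return img.T[..., np.newaxis]
```

Given a 64x128 (rows x columns) grayscale crop, this returns an array of shape (128, 64, 1) with values in [0, 1].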

azhuchkov commented 3 years ago

Hmm... @dimabendera What version of OpenCV are you using?

I copied and pasted your example, updated the paths to nomeroff-net and the image, and swapped the following lines:

img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
img = cv2.imread('../../examples/crop_np_images/RP70012.png')

and it results in:

cv2.error: OpenCV(4.5.1) /private/var/folders/nz/vv4_9tw56nv9k3tkvyszvwg80000gn/T/pip-req-build-oe0iat4a/opencv/modules/dnn/src/dnn.cpp:4037: error: (-215:Assertion failed) !empty() in function 'forward'

I think it expects a list of layers as an argument, but despite the absence of errors during model loading, net.getLayerNames() returns an empty list.

I also tried passing the saved_model.pb file as an argument and got:

FAILED: ReadProtoFromBinaryFile(param_file, param). Failed to parse GraphDef file
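
Both results would be consistent with a format mismatch: tf.saved_model.save writes a SavedModel directory, and the saved_model.pb inside it holds a SavedModel message rather than a bare GraphDef, whereas cv2.dnn.readNetFromTensorflow is documented to take the path to a frozen GraphDef .pb file. The check below is a rough sketch for telling the two kinds of export apart before handing a path to OpenCV. The helper name classify_tf_export and its return labels are hypothetical, and the check is a filename heuristic only: it does not parse the protobuf.

```python
from pathlib import Path


def classify_tf_export(path):
    """Guess what kind of TensorFlow export a path points to (heuristic only)."""
    p = Path(path)
    if p.is_dir():
        # A SavedModel export is a directory that contains saved_model.pb.
        if (p / "saved_model.pb").is_file():
            return "saved_model_dir"
        return "other_dir"
    if p.is_file():
        # saved_model.pb is a SavedModel message, not a frozen GraphDef.
        if p.name == "saved_model.pb":
            return "saved_model_proto"
        # Any other .pb file may be a frozen graph; this is not verified here.
        if p.suffix == ".pb":
            return "pb_file"
        return "other_file"
    return "missing"
```

Only a path classified as "pb_file" is a plausible candidate for readNetFromTensorflow, and even then the file still has to be a frozen graph that OpenCV's importer can parse.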