tmbdev / clstm

A small C++ implementation of LSTM networks, focused on OCR.
Apache License 2.0

How to make predictions using python code #148

Open lorenzob opened 6 years ago

lorenzob commented 6 years ago

Hi, taking this example as a reference I wrote the following code:

import clstm
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
from scipy.ndimage import filters

def decode2(pred, codec, threshold = .5):
    eps = filters.gaussian_filter(pred[:,0,0],2,mode='nearest')
    loc = (np.roll(eps,-1)>eps) & (np.roll(eps,1)>eps) & (eps<threshold)
    classes = np.argmax(pred,axis=1)[:,0]
    codes = classes[loc]
    chars = [chr(codec[c]) for c in codes]
    return "".join(chars)    

def decode1(pred, codec):
    classes = np.argmax(pred,axis=1)[:,0]
    print(classes)
    codes = classes[(classes!=0) & (np.roll(classes,1)==0)]
    #[print(int(c)) for c in codes]
    chars = [codec.decode(int(c)) for c in codes]
    return "".join(chars)

img_name="new-22_mcrop8.png"

img=cv2.imread(img_name, 0)
h=img.shape[0]
img = img.T.reshape(img.shape[0]*img.shape[1])
print("img.shape", img.shape)

net = clstm.load_net("model-180000.clstm")
print(clstm.network_info(net))

noutput=net.codec.size()
ninput=h
print("in, out: ",ninput,noutput)

#plt.imshow(img.reshape(h,-1))
#plt.show()

print("img.shape:", img.shape)
xs = np.array(img.reshape(-1,h,1),'f')

#plt.imshow(xs.reshape(-1,h).T,cmap=plt.cm.gray)
#plt.show()

print("xs.shape", xs.shape)
net.inputs.aset(xs)
net.forward()
pred = net.outputs.array()
print("pred.shape", pred.shape)

#plt.imshow(pred.reshape(-1,noutput).T, interpolation='none')
#plt.show()

codec = [net.codec.decode(i) for i in range(net.codec.size())]
print("codec: ", codec)

print(decode1(pred, net.codec))
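
To check the decoding step on its own, without the model, here is a minimal sketch on a synthetic output array. The toy alphabet, the frame sequence and the name greedy_decode are made up for illustration. It follows the same idea as decode1 (class 0 is treated as blank, and a character is emitted when a non-blank frame follows a blank one), but it compares against a shifted copy instead of np.roll so that the last frame does not wrap around to the first:

```python
import numpy as np

def greedy_decode(pred, alphabet):
    """Emit one character each time a non-blank class follows a blank frame."""
    classes = np.argmax(pred, axis=1)[:, 0]      # best class per time step
    prev = np.concatenate(([0], classes[:-1]))   # previous frame; blank before the start
    starts = (classes != 0) & (prev == 0)        # first frame of each non-blank run
    return "".join(alphabet[c] for c in classes[starts])

# Hypothetical toy data: index 0 is the blank class.
alphabet = ["", "A", "B"]
frames = [0, 1, 1, 0, 2, 0, 0, 1]

# One-hot "predictions" with the same (time, classes, 1) layout as net.outputs.array().
pred = np.zeros((len(frames), len(alphabet), 1), dtype=np.float32)
pred[np.arange(len(frames)), frames, 0] = 1.0

decoded = greedy_decode(pred, alphabet)
print(decoded)
```

With a real model the alphabet would come from net.codec rather than a hand-written list.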

This is the image:

new-22_mcrop8

and the expected output is obviously TRENTO. This is the model:

model-180000.clstm.zip

The model was trained with clstmocrtrain and works perfectly when I use clstmocr.

I had a look at the C++ code and I can see this:

raw() = -raw() + Float(1.0);
[...]
normalizer->normalize(image, raw);

before the forward call. Maybe this is the problem: my Python code feeds the raw image to the network without inverting or normalizing it. Can I call these from Python (I doubt it)? Do I have to rewrite them in Python? Is there a simpler way that I missed, like a predict() method? Should I have a look at Kraken?
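
If I do have to rewrite it, this is a minimal sketch of what I think the preprocessing could look like in plain NumPy. The inversion mirrors the raw() line above, assuming the image is scaled to [0, 1] first. The height rescale is only a crude stand-in and is not clstm's normalizer, whose exact behaviour I have not reproduced. The function names and the target height of 48 are my own placeholders:

```python
import numpy as np

def invert_image(gray_u8):
    """Scale an 8-bit grayscale image to [0, 1] and invert it (ink becomes high)."""
    raw = np.asarray(gray_u8, dtype=np.float32) / 255.0
    return 1.0 - raw  # same operation as raw() = -raw() + Float(1.0)

def rescale_height(image, target_height):
    """Nearest-neighbour rescale to a fixed height, keeping the aspect ratio.

    Assumption: a simple placeholder, NOT equivalent to normalizer->normalize().
    """
    h, w = image.shape
    scale = target_height / h
    new_w = max(1, int(round(w * scale)))
    rows = np.minimum((np.arange(target_height) / scale).astype(int), h - 1)
    cols = np.minimum((np.arange(new_w) / scale).astype(int), w - 1)
    return image[np.ix_(rows, cols)]

def to_sequence(image):
    """Turn a (height, width) image into a (width, height, 1) float32 sequence."""
    return np.ascontiguousarray(image.T[:, :, None], dtype=np.float32)

# Synthetic stand-in for cv2.imread(img_name, 0): white page with one black column.
page = np.full((24, 60), 255, dtype=np.uint8)
page[:, 10] = 0

xs = to_sequence(rescale_height(invert_image(page), 48))
print(xs.shape, xs.min(), xs.max())
```

The resulting xs has the same (width, height, 1) layout that I pass to net.inputs.aset(xs) above, so it could be dropped in there once the height matches what the model expects.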

Thanks for any suggestion.