weinman / cnn_lstm_ctc_ocr

Tensorflow-based CNN+LSTM trained with CTC-loss for OCR
GNU General Public License v3.0

How to batch inference? #71

Open igorvishnevskiy opened 2 years ago

igorvishnevskiy commented 2 years ago

Hello Jerod @weinman. Does this platform support batch inference? How can I submit a list of images for prediction all at once? Thank you.

weinman commented 2 years ago

Thanks for the question. I want to make sure I understand what you're trying to do. (While also confessing I will need to page many details about the code back into my brain.)

Do you just want (1) an easy way to map from a list of image paths to the predicted strings (i.e., piped to standard output), or do you actually (2) want the tensor of images packed together as input (i.e., you'll need the sequence lengths, too) and run a single GPU-parallel inference on them, producing the prediction sequences (i.e. as a sparse tensor)?

Both should be possible, but the second may require deriving some new client code for the Estimator that blends elements of validate.py and test.py.
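For concreteness, by "packed together" I mean something like the following rough sketch (plain NumPy; the fixed height of 32, grayscale images, and zero-padding convention are assumptions here, not necessarily what this repo's pipeline does; the model's own width-to-sequence-length mapping would apply downstream of the recorded widths):

import numpy as np

def pack_batch(images):
    """Pad variable-width grayscale images (height 32) to a common width
    and record each image's true width, from which CTC sequence lengths
    can later be derived."""
    height = 32
    widths = [im.shape[1] for im in images]
    max_width = max(widths)
    batch = np.zeros((len(images), height, max_width, 1), dtype=np.float32)
    for i, im in enumerate(images):
        batch[i, :, :im.shape[1], 0] = im
    return batch, np.asarray(widths, dtype=np.int32)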

igorvishnevskiy commented 2 years ago

Hello Jerod @weinman. Definitely (2): the tensor of images packed together as input, running a single GPU-parallel inference on them and producing the prediction sequences (i.e., as a sparse tensor). Thank you for helping me understand how this could be achieved.

weinman commented 2 years ago

Well, it should be possible without too much trouble, I hope.

I note that both model_fn.predict_fn (used by validate.py to predict strings) and model_fn.evaluate_fn use the same underlying processing, which can handle batched inputs (namely, calls to model_fn._get_output).

When I try a simple/dummy example taking the batched dataset from train.py and putting it through the predictor, i.e.,

import tensorflow as tf

# both from cnn_lstm_ctc_ocr/src
import model_fn
import train
classifier = tf.estimator.Estimator( 
    model_fn=model_fn.predict_fn(
        None,None), # no lexicon or prior weight
    model_dir='/tmp/model' )
predictions = classifier.predict( input_fn=train._get_input )
results = next(predictions)

it still seems to give me a single (unbatched) example.

I don't have more time right now to test this further (and determine what I'm probably doing wrong), but if you can get a batched tensor into model_fn.predict_fn, I think it just might work as you want. (Or at least be a start for doing so...)
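One thing worth trying (untested on my end): tf.estimator.Estimator.predict unbatches its outputs by default, yielding one example at a time. Passing yield_single_examples=False (available in TF >= 1.7) should make it yield whole batches instead:

predictions = classifier.predict(
    input_fn=train._get_input,
    yield_single_examples=False )  # yield full batches, not single examples
batch = next(predictions)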

igorvishnevskiy commented 2 years ago

Thank you so much, Jerod @weinman. I will take over from here and check in a PR as soon as I get it working.

weinman commented 2 years ago

As a quick follow-up before I completely lose track of this thread: the confusing issue that stymied me was that I thought the underlying tensors should be producing batched outputs, even in predict mode.

I thought this because the following test seems to indicate so:

import tensorflow as tf
from tensorflow.contrib import learn

import model_fn
import train

# build the batched dataset used for training
ds = train._get_input()
[features, labels] = tf.data.make_one_shot_iterator(ds).get_next()

mode = learn.ModeKeys.EVAL
logits, sequence_length = model_fn._get_image_info(features, mode)

# no lexicon or prior weight
predictions, log_probs = model_fn._get_output(logits, sequence_length, None, None)

with tf.Session() as sess:
    tf.global_variables_initializer().run()  # initialize model variables
    [logit, pred] = sess.run([logits, predictions])
    print(logit.shape)
    print(pred[0].dense_shape)

Using the simple tfrecord file included in the repo, for me this produces:

(60, 32, 63)
[32 25]

which seems to indicate the full batch (size 32) is being produced: the logits are time x batch x classes (60 x 32 x 63), and the sparse prediction's dense_shape is [batch, max_sequence_length] = [32, 25].
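To turn that sparse prediction into per-image strings, something like the following sketch should work (OUT_CHARSET here is a hypothetical stand-in for the repo's actual index-to-character mapping):

import numpy as np

# pred[0] is a SparseTensorValue; scatter its values into a dense array
# padded with -1, then map label indices to characters
sp = pred[0]
dense = np.full(sp.dense_shape, -1, dtype=np.int64)
dense[tuple(sp.indices.T)] = sp.values
strings = [''.join(OUT_CHARSET[i] for i in row if i >= 0)
           for row in dense]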

I don't know if that helps, but I hope you figure out what you're looking for. I think it would be useful for others as well.