sherjilozair / char-rnn-tensorflow

Multi-layer Recurrent Neural Networks (LSTM, RNN) for character-level language models in Python using Tensorflow
MIT License
2.64k stars 964 forks source link

How to run the trained model using flask api? #129

Open imsurinder90 opened 6 years ago

imsurinder90 commented 6 years ago

I created a route localhost:5050/predict to run the model with the given statement.

{
    "statement": "You helped the customer in troubleshooting the Cable issue and you asked "
}

but It gives error below error:

  File "C:\Users\surinder.kumar01\AppData\Local\conda\conda\envs\tia\lib\site-packages\flask_restful\__init__.py", line 595, in dispatch_request
    resp = meth(*args, **kwargs)
  File "D:\surinder\ds\test\text_classification_projects\char-rnn-tensorflow\wordpredict.py", line 45, in post
    saver = tf.train.Saver()
  File "C:\Users\surinder.kumar01\AppData\Local\conda\conda\envs\tia\lib\site-packages\tensorflow\python\training\saver.py", line 1311, in __init__
    self.build()
  File "C:\Users\surinder.kumar01\AppData\Local\conda\conda\envs\tia\lib\site-packages\tensorflow\python\training\saver.py", line 1320, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "C:\Users\surinder.kumar01\AppData\Local\conda\conda\envs\tia\lib\site-packages\tensorflow\python\training\saver.py", line 1345, in _build
    raise ValueError("No variables to save")
ValueError: No variables to save

Here is the code:

from __future__ import print_function
import os
from six.moves import cPickle
import tensorflow as tf
from model import Model

from flask import Flask, request
from flask_restful import Resource, Api

app = Flask(__name__)
api = Api(app)

params = {
    'save_dir': 'save',
    'prime': '',
    'n': 500,
    'sample': 2
}

def get_model():
    with open(os.path.join(params['save_dir'], 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(params['save_dir'], 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    return chars, vocab, Model(saved_args, training=False)

class predict(Resource):

    chars, vocab, model = get_model()

    def sample(self, statement, args, chars, vocab, model, saver, ckpt):
        with tf.Session() as sess:
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                result = model.sample(sess, chars, vocab, args['n'], statement, args['sample']).encode('utf-8')
                return result

    def post(self):
        statement = request.get_json(silent=True)['statement']
        result = None

        # tf.global_variables_initializer().run()
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(params['save_dir'])
        # with tf.Session() as sess:
        result = self.sample(
            statement, params, predict.chars,
            predict.vocab, predict.model, saver, ckpt
        ).decode('utf-8').split(".")[0]
        return {
            'statement': statement,
            'full_statement': result
        }

api.add_resource(predict, '/')

if __name__ == "__main__":
    app.run(debug=True)

Please help.