jfkirk / tensorrec

A TensorFlow recommendation algorithm and framework in Python.
Apache License 2.0

Getting error when loading on flask | Session issue #120

Status: Open. jaiswalvineet opened this issue 5 years ago.

jaiswalvineet commented 5 years ago

I am getting a session-management error when I try to use TensorRec in a Flask application. Can you please help?

ValueError: Tensor("TensorSliceDataset:0", shape=(), dtype=variant) must be from the same graph as Tensor("Iterator:0", shape=(), dtype=resource).

Sample code structure


from io import StringIO
import flask
import pandas as pd
import tensorrec as tr
import tensorflow as tf

class ScoringService(object):
    model = None                # cached TensorRec model, loaded once per process
    model_path = 'model path'   # placeholder for the saved-model directory

    @classmethod
    def get_model(cls):
        if cls.model is None:
            cls.model = tr.TensorRec.load_model(cls.model_path)
        return cls.model

    @classmethod
    def get_reco(cls, model, use_ft, ite_ft):
        # Discards the current default graph before predicting
        tf.reset_default_graph()
        predictions = model.predict(use_ft, ite_ft)
        return predictions

    @classmethod
    def predict(cls, input):
        clf = cls.get_model()
        # use_ft / ite_ft are built from `input` in the full predictor (not shown here)
        n_reco_test = cls.get_reco(clf, use_ft, ite_ft)
        return n_reco_test

# The flask app for serving predictions
app = flask.Flask(__name__)

@app.route('/invocations', methods=['POST'])
def transformation():
    data = None

    # Convert from CSV to pandas
    if flask.request.content_type == 'text/csv':
        data = flask.request.data.decode('utf-8')
        s = StringIO(data)
        data = pd.read_csv(s, header=None)
    else:
        return flask.Response(response='This predictor only supports CSV data', status=415, mimetype='text/plain')
    # Do the prediction and serialize the score matrix back to CSV
    predictions = ScoringService.predict(data)
    out = pd.DataFrame(predictions).to_csv(header=False, index=False)
    return flask.Response(response=out, status=200, mimetype='text/csv')

jfkirk commented 5 years ago

Hey @jaiswalvineet! Thanks for reporting this. What is the purpose of the tf.reset_default_graph() call inside get_reco()? It is likely the culprit. Is the call to model.predict() the line that fails?

jaiswalvineet commented 5 years ago

Hey @jfkirk, thanks for replying. I added tf.reset_default_graph() to override the graph so that I would not hit this error, but it does not help. I trained the model, saved it to a pickle file, and get this error when I try to predict with the model loaded from that file. Yes, model.predict() is the call that fails.

jaiswalvineet commented 5 years ago

"POST /invocations HTTP/1.1" 500 291 "-" "AHC/2.0"
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 350, in _apply_op_helper
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 5637, in _get_graph_from_inputs
    _assert_same_graph(original_graph_element, graph_element)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 5573, in _assert_same_graph
    original_item)

jaiswalvineet commented 5 years ago

It seems TensorRec does something with the session that overrides the default one. The traceback below shows that TensorRec's predict() does extra work before predicting (it creates datasets and iterator initializers):

Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/flask/app.py", line 2292, in wsgi_app
    response = self.full_dispatch_request()
  File "/usr/local/lib/python3.5/dist-packages/flask/app.py", line 1815, in full_dispatch_request
    rv = self.handle_user_exception(e)
  File "/usr/local/lib/python3.5/dist-packages/flask/app.py", line 1718, in handle_user_exception
    reraise(exc_type, exc_value, tb)
  File "/usr/local/lib/python3.5/dist-packages/flask/_compat.py", line 35, in reraise
    raise value
  File "/usr/local/lib/python3.5/dist-packages/flask/app.py", line 1813, in full_dispatch_request
    rv = self.dispatch_request()
  File "/usr/local/lib/python3.5/dist-packages/flask/app.py", line 1799, in dispatch_request
    return self.view_functions[rule.endpoint](**req.view_args)
  File "/opt/ml/code/predictor.py", line 143, in transformation
    predictions = ScoringService.predict(data)
  File "/opt/ml/code/predictor.py", line 105, in predict
    n_reco_test = cls.get_reco(clf, input, n_rec, user_features, ite_ft, u_map, i_map, i_order)
  File "/opt/ml/code/predictor.py", line 71, in get_reco
    predictions = model.predict(use_ft, ite_ft)
  File "/usr/local/lib/python3.5/dist-packages/tensorrec/tensorrec.py", line 663, in predict
    item_features=item_features)
  File "/usr/local/lib/python3.5/dist-packages/tensorrec/tensorrec.py", line 256, in _create_datasets_and_initializers
    for dataset in user_features_datasets]
  File "/usr/local/lib/python3.5/dist-packages/tensorrec/tensorrec.py", line 256, in <listcomp>
    for dataset in user_features_datasets]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/iterator_ops.py", line 308, in make_initializer
    dataset._as_variant_tensor(), self._iterator_resource, name=name)  # pylint: disable=protected-access
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_dataset_ops.py", line 1822, in make_iterator
    "MakeIterator", dataset=dataset, iterator=iterator, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 350, in _apply_op_helper
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 5637, in _get_graph_from_inputs
    _assert_same_graph(original_graph_element, graph_element)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 5573, in _assert_same_graph
    original_item))

jfkirk commented 5 years ago

Hey @jaiswalvineet, calling tf.reset_default_graph() after loading the model blows away the graph and creates a new, empty one. I've reproduced this locally: the reset graph does not contain the model, so the subsequent call to predict() fails with the error you're seeing.
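The failure described above can be reproduced without TensorRec. The sketch below is a minimal illustration, assuming a TensorFlow release that ships `tensorflow.compat.v1`; the function `predict_after_reset()` and its two constant tensors are hypothetical stand-ins for the loaded model and for the per-call input ops. Two explicit `tf.Graph()` objects play the roles of the graph the model was loaded into and the fresh graph that `tf.reset_default_graph()` installs.

```python
import tensorflow.compat.v1 as tf


def predict_after_reset():
    """Mimic: load a model, reset the default graph, then build new ops."""
    model_graph = tf.Graph()
    with model_graph.as_default():
        # Hypothetical stand-in for tensors created when the model was loaded.
        weights = tf.constant([1.0, 2.0], name="weights")

    # reset_default_graph() swaps in a brand-new, empty default graph;
    # an explicit fresh Graph models that here.
    fresh_graph = tf.Graph()
    with fresh_graph.as_default():
        # New per-call ops (hypothetical stand-in for input pipelines) go here...
        inputs = tf.constant([3.0, 4.0], name="inputs")
        # ...but combining them with the model's tensors raises ValueError.
        return tf.add(weights, inputs)


if __name__ == "__main__":
    try:
        predict_after_reset()
    except ValueError as err:
        print("ValueError:", err)
```

The message has the same "must be from the same graph as" shape as the one in the report, where the two tensors are TensorRec's dataset and iterator.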

If you remove the call to tf.reset_default_graph() do you see the same error?

jaiswalvineet commented 5 years ago

I removed the tf.reset_default_graph() call but still get the same error. It seems related to how the Flask app persists the graph: it is strange that it works when I run it directly but fails when I call it from the Flask app. If it works the first time for you from Flask, refresh the page in the browser and you should get the error. If not, please share your Flask app. Thanks in advance!
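One difference between running the code directly and calling it through Flask is threading: Flask's development server handles requests on worker threads by default in recent versions. In TF1-style code, the stack that `graph.as_default()` pushes onto is thread-local, so a graph made the default in one thread is not the default in another. The sketch below, assuming `tensorflow.compat.v1` is available, shows that behavior; `default_graph_by_thread()` is a hypothetical helper, and whether this is the exact cause in this report is an assumption.

```python
import threading

import tensorflow.compat.v1 as tf


def default_graph_by_thread(graph):
    """Report whether the caller and a new thread see `graph` as the default."""
    result = {}

    def worker():
        # Runs on a separate thread, like a Flask request handler.
        result["worker"] = tf.get_default_graph() is graph

    with graph.as_default():
        result["caller"] = tf.get_default_graph() is graph
        t = threading.Thread(target=worker)
        t.start()
        t.join()
    return result


if __name__ == "__main__":
    print(default_graph_by_thread(tf.Graph()))
```

The worker thread falls back to the global default graph rather than the one the caller entered, which is why the fix below re-enters the stored graph inside each request.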

gallmerci commented 5 years ago

@jaiswalvineet I had the same problem. The TF graph is not thread-safe; it needs to be defined globally and reused for every prediction. Here is how I solved it (the solution actually comes from https://github.com/keras-team/keras/issues/2397#issuecomment-306687500):

I added a global model variable, loaded the model once, and stored the graph right after loading it:

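import tensorflow as tf
from tensorrec import TensorRec

model_path = 'model path'  # placeholder for the saved-model directory
model = TensorRec.load_model(directory_path=model_path)
graph = tf.get_default_graph()  # the graph the loaded model lives in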

Then, on every request, I made that stored graph the default again and opened a new session:

with graph.as_default():
    session = tf.Session()
    with session.as_default():
        # Now do the prediction
        predictions = ScoringService.predict(data)
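The same load-once, reuse-the-graph pattern can be exercised end to end without TensorRec. The sketch below is an illustration under stated assumptions: `tensorflow.compat.v1` is available, a constant tensor is a hypothetical stand-in for the model returned by `TensorRec.load_model()`, and `predict()` and `_graph_lock` are hypothetical names. Judging from the traceback, TensorRec's predict() builds new dataset ops on each call, so the sketch also builds a new op per request and serializes graph construction with a lock, since graph construction is not thread-safe.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

import tensorflow.compat.v1 as tf

# Load once, at import time. Hypothetical stand-in for TensorRec.load_model():
# a constant plays the role of the model's weights.
graph = tf.Graph()
with graph.as_default():
    weights = tf.constant([[1.0], [2.0]], name="weights")
session = tf.Session(graph=graph)

# Serialize requests that add ops to the shared graph.
_graph_lock = threading.Lock()


def predict(rows):
    """Score rows of shape [n, 2] against the weights loaded above."""
    with _graph_lock, graph.as_default():
        # Per-request ops now land in the same graph as the loaded weights,
        # whichever worker thread handles the request.
        batch = tf.constant(rows, dtype=tf.float32)
        scores = tf.matmul(batch, weights)
        return session.run(scores)


if __name__ == "__main__":
    # Simulate Flask handling requests on several worker threads.
    with ThreadPoolExecutor(max_workers=4) as pool:
        for result in pool.map(predict, [[[1.0, 1.0]], [[2.0, 0.0]]]):
            print(result.tolist())
```

This sketch reuses one session bound to the stored graph instead of opening a new one per call; whether TensorRec picks up a per-request session is not something it checks. Because ops are added on every call, the graph grows over the life of the process, and the lock trades request concurrency for safety.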