patlevin / tfjs-to-tf

A TensorFlow.js Graph Model Converter
MIT License

TypeError: graph must be a tf.Graph #9

Closed: allo- closed this issue 4 years ago

allo- commented 4 years ago

Since upgrading from v0.5.0 to the current master, I get the following error:

Traceback (most recent call last):
  File "./virtual_webcam.py", line 211, in <module>
    sess = tf.compat.v1.Session(graph=graph)
  File "venvs/virtual_webcam_background/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1587, in __init__
    super(Session, self).__init__(target, graph, config=config)
  File "venvs/virtual_webcam_background/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 657, in __init__
    raise TypeError('graph must be a tf.Graph, but got %s' % type(graph))
TypeError: graph must be a tf.Graph, but got <class 'tensorflow.core.framework.graph_pb2.GraphDef'>

The model is body-pix (mobilenet) and I load it like this:

import tfjs_graph_converter.api as tfjs_api
graph = tfjs_api.load_graph_model(model_path) 

What do I need to change for the current version?

patlevin commented 4 years ago

The current master is not a finished release yet. I will look into the issue, but there are a lot of changes coming (including publishing to PyPI, as per your suggestion 😀).

Meanwhile, your particular issue can be solved by using

sess = tf.compat.v1.Session(graph=graph.as_graph_def())

allo- commented 4 years ago

Okay. I just wanted to try adding the master branch (or a known-working commit) to my requirements until there is a new release.

allo- commented 4 years ago

Isn't the problem the other way round? I need a graph, but the function returns a GraphDef.

I already tried to use the GraphDef like this https://stackoverflow.com/a/48318045 but then I got other errors from inside tensorflow.

I guess I'll wait for the release, then, rather than add workarounds that may have to change by the time it comes out.

patlevin commented 4 years ago

Isn't the problem the other way round? I need a graph, but the function returns a GraphDef.

You're absolutely right - I got that backwards. The solution given on stackoverflow should work, though.

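import tensorflow as tf
import tfjs_graph_converter.api as tfjs_api

graph_def = tfjs_api.load_graph_model(model_path)  # returns a GraphDef
graph = tf.Graph()
with tf.compat.v1.Session(graph=graph) as sess:  # makes `graph` the default
    tf.graph_util.import_graph_def(graph_def, name='')
    # ...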

I will, however, look into it anyway.
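
A minimal sketch of the same fix as a reusable helper, assuming TensorFlow 1.15 or 2.x with the `tf.compat.v1` API: the `GraphDef` returned by `load_graph_model()` is imported into its own `tf.Graph`, which `tf.compat.v1.Session` accepts. The helper name `graph_def_to_graph` and the tiny stand-in model (tensors `input:0` and `output:0`) are hypothetical; with body-pix you would pass the real `GraphDef` and use its own tensor names.

```python
import tensorflow as tf


def graph_def_to_graph(graph_def):
    """Import a GraphDef into a fresh tf.Graph that a v1 Session accepts."""
    graph = tf.Graph()
    with graph.as_default():
        # name='' keeps the original node names (no "import/" prefix).
        tf.graph_util.import_graph_def(graph_def, name='')
    return graph


def build_demo_graph_def():
    """Hypothetical stand-in for load_graph_model(): a GraphDef computing input * 2."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name='input')
        tf.identity(x * 2.0, name='output')
    return graph.as_graph_def()


graph_def = build_demo_graph_def()   # a GraphDef: Session(graph=graph_def) raises TypeError
graph = graph_def_to_graph(graph_def)

with tf.compat.v1.Session(graph=graph) as sess:
    # Tensors can be fetched and fed by name.
    result = sess.run('output:0', feed_dict={'input:0': [[1.0, 2.0]]})
print(result)
```

Passing `graph` rather than `graph_def` to the session is what resolves the `TypeError` from the traceback.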

allo- commented 4 years ago

It works for me like this:

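import tensorflow as tf
import tfjs_graph_converter.api as tfjs_api
import tfjs_graph_converter.util as tfjs_util

graph_def = tfjs_api.load_graph_model(model_path)

sess = tf.compat.v1.Session()
# import into the default graph, which the session above also uses
tf.graph_util.import_graph_def(graph_def, name='')
graph = tf.compat.v1.get_default_graph()

input_tensor_names = tfjs_util.get_input_tensors(graph)
output_tensor_names = tfjs_util.get_output_tensors(graph)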
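
The version above imports the model into the process-wide default graph. A possible variant keeps the model in a dedicated graph and finds input and output names with plain TensorFlow. This is a heuristic sketch under stated assumptions, not how `tfjs_util.get_input_tensors` or `get_output_tensors` are implemented: it assumes every `Placeholder` op is a model input and every tensor that no other op consumes is a model output. The function name `find_io_tensor_names` and the demo graph are hypothetical.

```python
import tensorflow as tf


def find_io_tensor_names(graph):
    """Heuristic: Placeholder outputs are inputs; unconsumed tensors are outputs."""
    ops = graph.get_operations()
    # Any tensor that feeds some op is "consumed" and cannot be a final output.
    consumed = {t.name for op in ops for t in op.inputs}
    inputs = [t.name for op in ops if op.type == 'Placeholder' for t in op.outputs]
    outputs = [t.name for op in ops if op.type != 'Placeholder'
               for t in op.outputs if t.name not in consumed]
    return inputs, outputs


# Hypothetical stand-in for the converted model.
demo = tf.Graph()
with demo.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name='input')
    tf.identity(x * 2.0, name='output')

# Import into a dedicated graph instead of the default graph.
graph = tf.Graph()
with graph.as_default():
    tf.graph_util.import_graph_def(demo.as_graph_def(), name='')

inputs, outputs = find_io_tensor_names(graph)
print(inputs, outputs)
```

Dead-end nodes such as unused constants would also be reported as outputs, so treat the result as a starting point to check against the model.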

patlevin commented 4 years ago

Good to hear it works 👍. It's issues like this that made me include type hints throughout the public API.

I will add some more usage examples to the documentation - both for TF v1 and TF v2 APIs.
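
For the TF v2 side, one option that avoids sessions entirely is the `tf.compat.v1.wrap_function` plus `prune()` pattern from TensorFlow's TF1-to-TF2 migration guide, which turns an imported `GraphDef` into a callable concrete function. This is a sketch assuming TF 2.x with eager execution enabled (the default); it is not the converter's documented API, and `wrap_graph_def` and the demo tensor names are hypothetical.

```python
import tensorflow as tf


def wrap_graph_def(graph_def, inputs, outputs):
    """Wrap a GraphDef in a TF2 concrete function mapping inputs to outputs."""
    def _import():
        # name='' keeps the original node names, e.g. 'input:0'.
        tf.compat.v1.import_graph_def(graph_def, name='')

    wrapped = tf.compat.v1.wrap_function(_import, [])
    g = wrapped.graph
    # prune() keeps only the subgraph between the requested feeds and fetches.
    return wrapped.prune(
        tf.nest.map_structure(g.as_graph_element, inputs),
        tf.nest.map_structure(g.as_graph_element, outputs))


# Hypothetical stand-in for load_graph_model(): a GraphDef computing input * 2.
demo = tf.Graph()
with demo.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name='input')
    tf.identity(x * 2.0, name='output')
graph_def = demo.as_graph_def()

model = wrap_graph_def(graph_def, inputs='input:0', outputs='output:0')
result = model(tf.constant([[1.0, 2.0]]))
print(result.numpy())
```

Unlike the Session-based snippets, nothing here touches the global default graph; the returned function is called like any other TF2 function.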