tensorflow / rust

Rust language bindings for TensorFlow
Apache License 2.0

Cannot import graph generated with Tensorflow 2.x #388

Open · ramon-garcia opened this issue 1 year ago

ramon-garcia commented 1 year ago

The examples (for instance, examples/regression.rs) include code that loads a graph from a .pb file:

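    use std::fs::File;
    use std::io::Read;
    use tensorflow::{Graph, ImportGraphDefOptions};
    // Read the serialized GraphDef from disk and import it into a new graph.
    let mut graph = Graph::new();
    let mut proto = Vec::new();
    File::open(filename)?.read_to_end(&mut proto)?;
    graph.import_graph_def(&proto, &ImportGraphDefOptions::new())?;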

The code works with a .pb file generated by the following TensorFlow 1.x Python script:

    import os
    import tensorflow as tf

    x = tf.placeholder(tf.float32, name='x')
    y = tf.placeholder(tf.float32, name='y')

    w = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='w')
    b = tf.Variable(tf.zeros([1]), name='b')
    y_hat = w * x + b

    loss = tf.reduce_mean(tf.square(y_hat - y))
    optimizer = tf.train.GradientDescentOptimizer(0.5)
    train = optimizer.minimize(loss, name='train')

    init = tf.variables_initializer(tf.global_variables(), name='init')

    definition = tf.Session().graph_def
    directory = 'examples/regression'
    tf.train.write_graph(definition, directory, 'model.pb', as_text=False)

However, suppose we run the script under TensorFlow 2.x after making a few small changes for compatibility:

    import os
    import tensorflow as tf
    import tensorflow.compat.v1 as tf1
    tf1.disable_eager_execution()

    x = tf1.placeholder(tf.float32, name='x')
    y = tf1.placeholder(tf.float32, name='y')

    w = tf.Variable(tf.random_uniform_initializer(minval=-1.0, maxval=1.0)(shape=[1]), name='w')
    b = tf.Variable(tf.zeros([1]), name='b')
    y_hat = w * x + b

    loss = tf.reduce_mean(tf.square(y_hat - y))
    optimizer = tf1.train.GradientDescentOptimizer(0.5)
    train = optimizer.minimize(loss, name='train')

    init = tf1.variables_initializer(tf1.global_variables(), name='init')

    definition = tf1.Session().graph_def
    directory = 'examples/regression'
    tf.io.write_graph(definition, directory, 'model.pb', as_text=False)

Then, when running the Rust example against the resulting model.pb, the following error appears after the TensorFlow session is initialized:

    Error: {inner:0x55665fcf34b0, InvalidArgument: Requested tensor type does not match actual tensor type: Resource vs Float}

ramon-garcia commented 1 year ago

ping

dskkato commented 1 year ago

The example code you tried does not work in TF 2.x. Please refer to the other examples discussed in https://github.com/tensorflow/rust/issues/309.

ramon-garcia commented 1 year ago

@dskkato I am not sure you are right. I tried writing the code in TF 2.x style (@tf.function) and the same error appeared.

Did you try importing the PB file and creating a session with the imported graph?

dskkato commented 1 year ago

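In that code, the names of the nodes that read the variables differ between TF1 and TF2. Under TF2 the variables are resource variables, so the node "w" itself outputs a resource handle rather than a float (hence "Resource vs Float"), and the value has to be fetched from the node that reads the variable.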

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    x = tf.placeholder(tf.float32, name='x')
    y = tf.placeholder(tf.float32, name='y')

    w = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='w')
    b = tf.Variable(tf.zeros([1]), name='b')
    y_hat = w * x + b

    loss = tf.reduce_mean(tf.square(y_hat - y))
    optimizer = tf.train.GradientDescentOptimizer(0.5)
    train = optimizer.minimize(loss, name='train')

    init = tf.variables_initializer(tf.global_variables(), name='init')

    definition = tf.Session().graph_def
    directory = 'examples/regression'
    tf.train.write_graph(definition, directory, 'model.pb', as_text=False)

    # for debug
    tf.train.write_graph(definition, directory, 'model.pbtxt', as_text=True)

As model.pbtxt shows, the nodes that read the variables appear to be named "w/Read/ReadVariableOp" for "w" and "b/Read/ReadVariableOp" for "b", so the example needs this change:

    diff --git a/examples/regression.rs b/examples/regression.rs
    index 5393b7ac..66f472d4 100644
    --- a/examples/regression.rs
    +++ b/examples/regression.rs
    @@ -55,8 +55,8 @@ fn main() -> Result<(), Box<dyn Error>> {
         let op_y = graph.operation_by_name_required("y")?;
         let op_init = graph.operation_by_name_required("init")?;
         let op_train = graph.operation_by_name_required("train")?;
    -    let op_w = graph.operation_by_name_required("w")?;
    -    let op_b = graph.operation_by_name_required("b")?;
    +    let op_w = graph.operation_by_name_required("w/Read/ReadVariableOp")?;
    +    let op_b = graph.operation_by_name_required("b/Read/ReadVariableOp")?;

         // Load the test data into the session.
         let mut init_step = SessionRunArgs::new();

dskkato commented 1 year ago

Please note that I'm not sure this is the canonical way to get the above node names.
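If you would rather not hard-code these names, one option is to search the debug model.pbtxt for ReadVariableOp nodes whose input is the variable. This is a sketch under stated assumptions, not the tensorflow crate's API: `read_ops_for_variable` and `string_field` are hypothetical helpers that use only the Rust standard library, and the line-based scan assumes the layout that `write_graph(..., as_text=True)` produces (one field per line, each nested message opened by a line ending in `{` and closed by a line containing only `}`). A graph may contain more than one ReadVariableOp for the same variable (for example, one created where the variable is used in an expression), so the helper returns every match in file order; the one named `<var>/Read/ReadVariableOp` is the one used in the diff above.

```rust
// Hypothetical helper, not part of the `tensorflow` crate: list the names of
// `ReadVariableOp` nodes in a text-format GraphDef (.pbtxt) whose first input
// is the resource variable `var`. Assumes one field per line, with nested
// messages opened by a line ending in `{` and closed by a line `}`.
fn read_ops_for_variable(pbtxt: &str, var: &str) -> Vec<String> {
    let mut found = Vec::new();
    let mut depth = 0usize;
    // Fields of the top-level message (e.g. `node { ... }`) being scanned.
    let mut name = String::new();
    let mut op = String::new();
    let mut inputs: Vec<String> = Vec::new();

    for raw in pbtxt.lines() {
        let line = raw.trim();
        if line.ends_with('{') {
            depth += 1;
            if depth == 1 {
                // Entering a new top-level message: reset collected fields.
                name.clear();
                op.clear();
                inputs.clear();
            }
        } else if line == "}" {
            // Closing a top-level message: record it if it reads `var`.
            if depth == 1
                && op == "ReadVariableOp"
                && inputs.first().map(String::as_str) == Some(var)
            {
                found.push(name.clone());
            }
            depth = depth.saturating_sub(1);
        } else if depth == 1 {
            // Only fields directly inside `node { ... }` matter; nested
            // `attr` entries (depth >= 2) are skipped.
            if let Some(v) = string_field(line, "name") {
                name = v;
            } else if let Some(v) = string_field(line, "op") {
                op = v;
            } else if let Some(v) = string_field(line, "input") {
                inputs.push(v);
            }
        }
    }
    found
}

// Hypothetical helper: parse a line of the form `key: "value"` into `value`.
fn string_field(line: &str, key: &str) -> Option<String> {
    let value = line.strip_prefix(key)?.strip_prefix(": ")?;
    Some(value.trim_matches('"').to_string())
}

fn main() {
    // Trimmed, illustrative excerpt shaped like a TF2 resource-variable graph.
    let pbtxt = r#"
node {
  name: "w"
  op: "VarHandleOp"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "w/Read/ReadVariableOp"
  op: "ReadVariableOp"
  input: "w"
}
"#;
    let reads = read_ops_for_variable(pbtxt, "w");
    println!("{:?}", reads); // prints ["w/Read/ReadVariableOp"]
}
```

A sturdier tool would decode the binary model.pb as a GraphDef protobuf instead of scanning text, but the text scan is enough for inspecting a small debug export.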

ramon-garcia commented 1 year ago

Thanks, this is exactly what I was looking for.