ramon-garcia opened this issue 1 year ago
ping
The example code you tried does not work in TF 2.x. Please refer to the other examples collected in https://github.com/tensorflow/rust/issues/309.
@dskkato I am not sure you are right. I tried writing the code in TF 2.x style (`@tf.function`) and the same error appeared.
Did you try importing the PB file and creating a session with the imported graph?
In that code, the names of the nodes that read the variables differ between TF1 and TF2.
```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

# Under TF 2.x, tf.Variable still creates resource variables here, so the
# values of "w" and "b" are read through ReadVariableOp nodes.
x = tf.placeholder(tf.float32, name='x')
y = tf.placeholder(tf.float32, name='y')
w = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='w')
b = tf.Variable(tf.zeros([1]), name='b')
y_hat = w * x + b
loss = tf.reduce_mean(tf.square(y_hat - y))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss, name='train')
init = tf.variables_initializer(tf.global_variables(), name='init')

definition = tf.Session().graph_def
directory = 'examples/regression'
tf.train.write_graph(definition, directory, 'model.pb', as_text=False)
# Text version of the same graph, for debugging
tf.train.write_graph(definition, directory, 'model.pbtxt', as_text=True)
```
As seen in model.pbtxt, the nodes that read the variables seem to be named "w/Read/ReadVariableOp" for "w" and "b/Read/ReadVariableOp" for "b", respectively:
```diff
diff --git a/examples/regression.rs b/examples/regression.rs
index 5393b7ac..66f472d4 100644
--- a/examples/regression.rs
+++ b/examples/regression.rs
@@ -55,8 +55,8 @@ fn main() -> Result<(), Box<dyn Error>> {
     let op_y = graph.operation_by_name_required("y")?;
     let op_init = graph.operation_by_name_required("init")?;
     let op_train = graph.operation_by_name_required("train")?;
-    let op_w = graph.operation_by_name_required("w")?;
-    let op_b = graph.operation_by_name_required("b")?;
+    let op_w = graph.operation_by_name_required("w/Read/ReadVariableOp")?;
+    let op_b = graph.operation_by_name_required("b/Read/ReadVariableOp")?;
 
     // Load the test data into the session.
     let mut init_step = SessionRunArgs::new();
```
Please note that I'm not sure this is the canonical way to get the above node names.
Thanks, this is exactly what I was looking for.
The examples, for instance examples/regression.rs, include code to load a graph from a .pb file. The code works with a .pb file generated with TensorFlow 1.x for Python.
But if we run this Python file under TensorFlow 2.x after making some small changes for compatibility, and then run the example, the following error appears after the TensorFlow session is initialized:
Error: {inner:0x55665fcf34b0, InvalidArgument: Requested tensor type does not match actual tensor type: Resource vs Float}