tensorflow / java

Java bindings for TensorFlow
Apache License 2.0
812 stars 200 forks source link

"Table not initialized" when loading model in Java #183

Open WittyLLL opened 3 years ago

WittyLLL commented 3 years ago

I am trying to use the tensorflow model in java,I convert a text classification model (with tf.lookup) to fomat .pb and want to load it in JAVA.But got "Table not initialized" error.

2021-01-04 14:00:10.713588: W tensorflow/core/framework/op_kernel.cc:1651] OP_REQUIRES failed at lookup_table_op.cc:809 : Failed precondition: Table not initialized.
Exception in thread "main" java.lang.IllegalStateException: Table not initialized.
     [[{{node graph/hash_table_Lookup/LookupTableFindV2}}]]
    at org.tensorflow.Session.run(Native Method)
    at org.tensorflow.Session.access$100(Session.java:48)
    at org.tensorflow.Session$Runner.runHelper(Session.java:326)
    at org.tensorflow.Session$Runner.run(Session.java:276)
    at ctest.Ttest.predict(Ttest.java:32)
    at ctest.Ttest.main(Ttest.java:13)

here is my code: In PYTHON

import os
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.ops.lookup_ops import HashTable, KeyValueTensorInitializer

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
OUTPUT_FOLDER = ''
OUTPUT_NAME = 'hash_table.pb'
OUTPUT_NAMES = ['graph/output', 'init_all_tables']

def build_graph():
    d = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
    init = KeyValueTensorInitializer(list(d.keys()), list(d.values()))
    hash_table = HashTable(init, default_value=-1)
    data = tf.placeholder(tf.string, (None,), name='data')
    values = hash_table.lookup(data)
    output = tf.identity(values * 2, 'output')

def freeze_graph():
    with tf.Graph().as_default() as graph:
        with tf.name_scope('graph'):
            build_graph()

        with tf.Session(graph=graph) as sess:
            sess.run(tf.tables_initializer())
            print(sess.run('graph/output:0', feed_dict={'graph/data:0': ['a', 'b', 'c', 'd', 'e']}))
            frozen_graph = convert_variables_to_constants(sess, sess.graph_def, OUTPUT_NAMES)
            tf.train.write_graph(frozen_graph, OUTPUT_FOLDER, OUTPUT_NAME, as_text=False)

def load_frozen_graph():
    with open(os.path.join(OUTPUT_FOLDER, OUTPUT_NAME), 'rb') as f:
        output_graph_def = tf.GraphDef()
        output_graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(output_graph_def, name='')
        with tf.Session(graph=graph) as sess:
            try:
                sess.run(graph.get_operation_by_name('init_all_tables'))
            except KeyError:
                pass
            print(sess.run('graph/output:0', feed_dict={'graph/data:0': ['a', 'b', 'c', 'd', 'e']}))

if __name__ == '__main__':
    freeze_graph()
    load_frozen_graph()

In JAVA

package ctest;

import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.nio.file.Files;
import java.nio.file.Paths;

public class Ttest {
    public static void main(String[] args) throws Exception {
        predict();
    }
    public static void predict() throws Exception {
        try (Graph graph = new Graph()) {
            graph.importGraphDef(Files.readAllBytes(Paths.get(
                    "/opt/resources/hash_table.pb"
            )));
            try (Session sess = new Session(graph)) {
                byte[][] matrix = new byte[1][];
                matrix[0] = "a".getBytes("UTF-8");
                Tensor< ? > out = sess.runner()
                        .feed("graph/data:0", Tensor.create(matrix)).fetch("graph/output:0").run().get(0);
                float[][] output = new float[1][(int) out.shape()[1]];
                out.copyTo(output);
                for(float i:output[0])
                    System.out.println(i);

            }
        }
    }
}

Any suggestions would be greatly appreciated.

Craigacp commented 3 years ago

What happens if you run the init_all_tables target from Java after you've built the session?

WittyLLL commented 3 years ago

What happens if you run the init_all_tables target from Java after you've built the session?

I am not sure how to run the init_all_tables target from Java, could you provide me with a demo?It will be very helpful to me.

Craigacp commented 3 years ago

sess.runner().addTarget("init_all_tables").run()

WittyLLL commented 3 years ago

sess.runner().addTarget("init_all_tables").run()

thank you for your patience,i try to add the init_all_tables target in Java,But it seems does not work.

here is my code

package ctest;

import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.nio.file.Files;
import java.nio.file.Paths;

public class Ttest {
    public static void main(String[] args) throws Exception {
        predict();
    }
    public static void predict() throws Exception {
        try (Graph graph = new Graph()) {
            graph.importGraphDef(Files.readAllBytes(Paths.get(
                    "/opt/resources/hash_table.pb"
            )));
            try (Session sess = new Session(graph)) {
                byte[][] matrix = new byte[1][];
                matrix[0] = "a".getBytes("UTF-8");
                Tensor< ? > out = sess.runner().addTarget("init_all_tables")
                        .feed("graph/data:0", Tensor.create(matrix)).fetch("graph/output:0").run().get(0);
                float[][] output = new float[1][(int) out.shape()[1]];
                out.copyTo(output);
                for(float i:output[0])
                    System.out.println(i);

            }
        }
    }
}

here is the error

2021-01-18 13:44:39.626529: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
2021-01-18 13:44:39.670960: W tensorflow/core/framework/op_kernel.cc:1651] OP_REQUIRES failed at lookup_table_op.cc:809 : Failed precondition: Table not initialized.
Exception in thread "main" java.lang.IllegalStateException: Table not initialized.
     [[{{node graph/hash_table_Lookup/LookupTableFindV2}}]]
    at org.tensorflow.Session.run(Native Method)
    at org.tensorflow.Session.access$100(Session.java:48)
    at org.tensorflow.Session$Runner.runHelper(Session.java:326)
    at org.tensorflow.Session$Runner.run(Session.java:276)
    at ctest.Ttest.predict(Ttest.java:22)
    at ctest.Ttest.main(Ttest.java:11)
Craigacp commented 3 years ago

Try running it separately before you try to access the table, but within the same session.