zhuantouer closed this issue 7 years ago.
```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.ConfigProto)

# Assumes FLAGS (batch_size, allow_soft_placement, log_device_placement)
# and the batch_iter helper are defined elsewhere in the project.


def predict(x_test, m_checkpoint_dir):
    # Find the most recent checkpoint in the directory
    checkpoint_file = tf.train.latest_checkpoint(m_checkpoint_dir)
    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            # input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]

            # Tensors we want to evaluate
            predictions = graph.get_operation_by_name("output/predictions").outputs[0]

            # Generate batches for one epoch, in order (no shuffling)
            batches = batch_iter(list(x_test), FLAGS.batch_size, 1, shuffle=False)

            # Collect the predictions here
            all_predictions = []
            for x_test_batch in batches:
                # Disable dropout at inference time
                batch_predictions = sess.run(predictions, {input_x: x_test_batch, dropout_keep_prob: 1.0})
                all_predictions = np.concatenate([all_predictions, batch_predictions])
    return all_predictions
```
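The loop above relies on `batch_iter(list(x_test), FLAGS.batch_size, 1, shuffle=False)` yielding the test set exactly once and in order, so the collected predictions line up one-to-one with `x_test` (the last batch may be smaller). A minimal stdlib sketch of that behavior follows; `batch_iter_sketch` and `fake_predict` are hypothetical stand-ins (the first for the project's `batch_iter`, the second for the `sess.run` call), not the actual code from the thread.

```python
import math
import random


def batch_iter_sketch(data, batch_size, num_epochs, shuffle=True, seed=None):
    """Yield successive mini-batches of `data` for `num_epochs` epochs.

    Hypothetical, simplified stand-in for the batch_iter helper used above.
    """
    data = list(data)
    num_batches = math.ceil(len(data) / batch_size)
    rng = random.Random(seed)
    for _ in range(num_epochs):
        order = list(range(len(data)))
        if shuffle:
            rng.shuffle(order)  # only reorder when asked to
        for b in range(num_batches):
            start = b * batch_size
            end = min(start + batch_size, len(data))  # last batch may be short
            yield [data[i] for i in order[start:end]]


def fake_predict(batch):
    # Hypothetical stand-in for sess.run(predictions, ...):
    # label an example 1 if its features sum to a positive value, else 0.
    return [1 if sum(x) > 0 else 0 for x in batch]


x_test = [[1, 2], [-3, 1], [0, 5], [-1, -1], [2, -1]]
all_predictions = []
for x_test_batch in batch_iter_sketch(x_test, batch_size=2, num_epochs=1, shuffle=False):
    all_predictions.extend(fake_predict(x_test_batch))

print(all_predictions)  # [1, 0, 1, 0, 1]
```

With `shuffle=False` and one epoch, prediction `i` always corresponds to `x_test[i]`; shuffling at inference time would break that correspondence.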
In my project, I use this code to predict file types.
@mylamour thanks, solved.
Hi Denny, I wrote an OOP-style CNN model, but I get an error when predicting. I know that if I define and write the code procedurally (C style) in one file, it is easy to restore the values, but with the OOP design something goes wrong. Here is my class:
and this is my train file:
and this is my predict file:
The error is:
Do you know how to load the variables when the model is written in an OOP style? Thanks!