galeone / dynamic-training-bench

Simplify the training and tuning of TensorFlow models
Mozilla Public License 2.0
213 stars 31 forks

How to use the model to extract features? #6

Closed guiyang882 closed 7 years ago

guiyang882 commented 7 years ago

Hi, I trained a SingleLayerCAE model on the Cifar10 dataset, and I want to use this model to extract features. How can I use your skeleton to test this model? Or should I write a program to do it myself?

thx

galeone commented 7 years ago

Since the model is just an encoding layer followed by a decoding layer, the features you want are the output of the encoding layer (a compact representation of the input).
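To make the idea concrete, here is a minimal NumPy sketch of an encoder followed by a decoder. It is not the dytb implementation: it uses dense layers instead of convolutions, random weights instead of trained ones, and the names `encode`, `decode`, and `code_dim` are hypothetical. The tanh on the decoder is also an assumption. The only point is that the features are whatever `encode` returns.

```python
import numpy as np

rng = np.random.default_rng(0)

input_dim = 32 * 32 * 3  # one flattened 32x32 RGB image
code_dim = 64            # hypothetical size of the compact representation

# Random weights stand in for trained ones (illustration only)
W_enc = rng.normal(scale=0.01, size=(input_dim, code_dim))
b_enc = np.zeros(code_dim)
W_dec = rng.normal(scale=0.01, size=(code_dim, input_dim))
b_dec = np.zeros(input_dim)


def encode(x):
    """Encoding layer: its output is the feature vector."""
    return np.tanh(x @ W_enc + b_enc)


def decode(code):
    """Decoding layer: maps the features back to input space."""
    return np.tanh(code @ W_dec + b_dec)


batch = rng.uniform(-1.0, 1.0, size=(4, input_dim))
features = encode(batch)            # keep this for feature extraction
reconstruction = decode(features)   # only needed to train the model
print(features.shape, reconstruction.shape)
```

Once training is done the decoder can be ignored: only the encoder output is used downstream.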

To extract them you have to:

  1. Train the model
  2. Get the checkpoint path of the model you want to restore (let's call it checkpoint_path)
  3. Define the model and load the weights from checkpoint_path
  4. Extract the encoding tensor from the graph
  5. Run the encoding tensor on your images to get the features

In short, something like this:

import sys
import tensorflow as tf
from dytb.train import train
from dytb.models.utils import variables_to_restore
from dytb.inputs import Cifar10
from dytb.models import SingleLayerCAE

# Train the model
model = SingleLayerCAE.SingleLayerCAE()
dataset = Cifar10.Cifar10()
info = train(model=model, dataset=dataset, hyperparameters={"epochs": 1})

# Get the checkpoint path
checkpoint_path = info["paths"]["best"]

# Define the model in the current graph (default)
# and define as input a placeholder
images = tf.placeholder(tf.float32, [None, 32, 32, 3])
is_training_, decode = model.get(images, train_phase=False)

# Extract the encoding tensor from the graph
graph = tf.get_default_graph()
encoding = graph.get_tensor_by_name("SingleLayerCAE/encode/Tanh:0")

saver = tf.train.Saver(variables_to_restore())
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    ckpt = tf.train.get_checkpoint_state(checkpoint_path)
    # Restore the trained weights; everything that uses sess must stay
    # inside the `with` block, otherwise the session is already closed
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        print('[!] No checkpoint file found')
        sys.exit()

    # Use it: read the image from somewhere and fill the placeholder.
    # image_batch is assumed to have the right shape (something like [1, 32, 32, 3])
    encoded_representation = sess.run(encoding, feed_dict={images: image_batch})
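
The snippet leaves `image_batch` undefined. As a rough sketch, one way to build it with NumPy is shown below. The helper name `to_batch` is hypothetical, and scaling pixels to [0, 1] is an assumption: use whatever preprocessing the Cifar10 input applied during training, which is not stated here.

```python
import numpy as np


def to_batch(image):
    """Turn one 32x32x3 uint8 image into a float32 batch of shape [1, 32, 32, 3]."""
    image = np.asarray(image)
    if image.shape != (32, 32, 3):
        raise ValueError("expected a 32x32 RGB image, got shape %s" % (image.shape,))
    # Assumed preprocessing: scale pixel values from [0, 255] to [0, 1]
    scaled = image.astype(np.float32) / 255.0
    # Add the leading batch dimension expected by the placeholder
    return scaled[np.newaxis, ...]


# A random image stands in for a real one read from disk
fake_image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
image_batch = to_batch(fake_image)
print(image_batch.shape, image_batch.dtype)
```

The value returned by `sess.run` is a NumPy array. If you need one flat feature vector per image, something like `encoded_representation.reshape(len(image_batch), -1)` should work, assuming the first axis of the encoding is the batch axis.
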

galeone commented 7 years ago

FYI: I simplified the feature extraction part, see #13