galeone / dynamic-training-bench

Simplify the training and tuning of Tensorflow models
Mozilla Public License 2.0

keras equivalent to extract features #13

Closed: dineshbvadhia closed this issue 7 years ago.

dineshbvadhia commented 7 years ago

In Keras, the documentation at https://keras.io/applications/#usage-examples-for-image-classification-models shows how to extract features from an arbitrary intermediate layer of a VGG19 model:

```python
from keras.applications.vgg19 import VGG19
from keras.preprocessing import image
from keras.applications.vgg19 import preprocess_input
from keras.models import Model
import numpy as np

# Full VGG19 with ImageNet weights; the sub-model shares them but stops at block4_pool
base_model = VGG19(weights='imagenet')
model = Model(inputs=base_model.input, outputs=base_model.get_layer('block4_pool').output)

img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)  # add a batch axis: (224, 224, 3) -> (1, 224, 224, 3)
x = preprocess_input(x)  # VGG-style preprocessing expected by the pretrained weights

block4_pool_features = model.predict(x)
```
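Conceptually, `Model(inputs=base_model.input, outputs=base_model.get_layer('block4_pool').output)` builds a sub-network that reuses the base model's weights but stops at the named layer. The framework-free sketch below illustrates only that idea: the layer stack is a hypothetical list of `(name, function)` pairs, not a Keras or dytb object, and the forward pass simply returns as soon as the requested layer has run.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical stand-in for a sequential network: each layer is (name, function).
Layer = Tuple[str, Callable[[List[float]], List[float]]]


def forward_until(layers: Sequence[Layer], x: List[float], layer_name: str) -> List[float]:
    """Run the layers in order and return the output of `layer_name`.

    Layers after `layer_name` are never executed, mirroring a sub-model whose
    output is an intermediate layer.
    """
    for name, fn in layers:
        x = fn(x)
        if name == layer_name:
            return x
    raise KeyError(f"no layer named {layer_name!r}")


# Toy layers standing in for conv/pool blocks (illustrative only).
toy_layers: List[Layer] = [
    ("block1", lambda v: [2 * e for e in v]),
    ("block1_pool", lambda v: [max(v[i:i + 2]) for i in range(0, len(v), 2)]),
    ("block2", lambda v: [e + 1 for e in v]),
]

features = forward_until(toy_layers, [1.0, 3.0, 2.0, 0.5], "block1_pool")
print(features)  # [6.0, 4.0]
```

Asking for a later layer name just runs more of the stack; asking for an unknown name raises `KeyError` instead of silently returning the final output.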

What would the equivalent be with dytb?

galeone commented 7 years ago

Look at #6: it is an example for a different model, but the steps are the same.

However, since this is frequently requested, I should add a method to every model's abstract class to simplify layer extraction and avoid making people rewrite the same code every time. I'll leave this issue open until I add this method. Thank you for pointing this out!
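A minimal sketch of what such a shared method could look like, assuming hypothetical names (`AbstractModel`, `inference`, `extract_features`) that are not dytb's actual API: subclasses only describe their layers, and the base class implements extraction once by recording each named intermediate output during the forward pass and returning the requested one.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class AbstractModel(ABC):
    """Hypothetical base class: the extraction logic is written once, here."""

    @abstractmethod
    def inference(self, x: List[float], endpoints: Dict[str, List[float]]) -> List[float]:
        """Run the model, storing every named intermediate output in `endpoints`."""

    def extract_features(self, x: List[float], layer_name: str) -> List[float]:
        endpoints: Dict[str, List[float]] = {}
        self.inference(x, endpoints)
        if layer_name not in endpoints:
            raise KeyError(f"{type(self).__name__} has no layer {layer_name!r}")
        return endpoints[layer_name]


class TinyNet(AbstractModel):
    """Toy subclass: it defines its layers but inherits the extraction logic."""

    def inference(self, x, endpoints):
        h = [max(0.0, e) for e in x]  # ReLU
        endpoints["relu1"] = h
        out = [sum(h)]  # one-unit "dense" layer with all weights equal to 1
        endpoints["logits"] = out
        return out


net = TinyNet()
print(net.extract_features([-1.0, 2.0, 3.0], "relu1"))  # [0.0, 2.0, 3.0]
```

Because `inference` is abstract, every new model is forced to expose its named endpoints, and no model has to reimplement `extract_features`.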

galeone commented 7 years ago

Alright! I implemented it in https://github.com/galeone/dynamic-training-bench/commit/2719a08a1d205911845a42e0c2d65b587ae6759f

Since every model has its own evaluator object, I added this feature to the evaluators. This way, feature extraction becomes as simple as:

```python
import tensorflow as tf

from dytb.models.predefined.VGG import VGG
from dytb.inputs.images import read_image

model = VGG()
# Read the image, add a batch dimension and resize it to the 32x32 CIFAR-10 input size
image = tf.image.resize_bilinear(
        tf.expand_dims(
            read_image("images/nocat.png", channel=3, image_type="png"),
            axis=0), (32, 32))
# Use the weights saved in the CIFAR-10 checkpoint and return the output of `layer_name`
features = model.evaluator.extract_features(
        checkpoint_path="../log/VGG/CIFAR-10_Momentum/best/",
        inputs=image,
        layer_name="VGG/pool1/MaxPool:0",
        num_classes=10)
```
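The `layer_name` argument above is a TensorFlow tensor name, which follows the `<op_name>:<output_index>` convention: `VGG/pool1/MaxPool` is the operation (nested under the `VGG/pool1` scope) and `:0` selects its first output. In TensorFlow 1.x such a name can be looked up with `graph.get_tensor_by_name(...)`. The helper below is a hypothetical utility, not part of dytb; it only splits a name into those parts, which may help when checking which layer a given name refers to.

```python
from typing import Tuple


def parse_tensor_name(tensor_name: str) -> Tuple[str, str, int]:
    """Split a TF tensor name "scope/op:index" into (scope, op_name, index).

    `scope` is everything before the last "/" of the op name, `op_name` is the
    full operation name, and `index` is the output slot (treated as 0 here if
    the ":<index>" suffix is missing).
    """
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep:  # no ":<index>" suffix, e.g. a bare op name
        op_name, index = tensor_name, "0"
    scope = op_name.rpartition("/")[0]
    return scope, op_name, int(index)


print(parse_tensor_name("VGG/pool1/MaxPool:0"))
# ('VGG/pool1', 'VGG/pool1/MaxPool', 0)
```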