jordipons / musicnn

Pronounced as "musician", musicnn is a set of pre-trained deep convolutional neural networks for music audio tagging.
ISC License

Inference with SavedModel format #16

Open shoegazerstella opened 3 years ago

shoegazerstella commented 3 years ago

Hi, and many thanks for releasing this project!

I would like to convert the models to the SavedModel format and run inference and feature extraction from there. I was able to convert one of the models to a saved_model.pb, although I am not sure this is the right way to do it:

```python
import os
import tensorflow as tf

def convert():

    MODEL_DIR = 'MTT_musicnn/'
    trained_checkpoint_prefix = MODEL_DIR
    export_dir = os.path.join('export_dir', MODEL_DIR)

    graph = tf.Graph()
    with tf.compat.v1.Session(graph=graph) as sess:
        # Restore from checkpoint
        loader = tf.compat.v1.train.import_meta_graph(trained_checkpoint_prefix + '.meta')
        loader.restore(sess, trained_checkpoint_prefix)

        # Export checkpoint to SavedModel
        builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(sess,
                                             [tf.saved_model.TRAINING, tf.saved_model.SERVING],
                                             strip_default_attrs=True)
        builder.save()
```
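A likely reason the loaded object has no usable entry point is that the export above passes no `signature_def_map`, so the SavedModel contains no serving signature to call. The following is a minimal sketch of exporting with a signature and calling it from TF2 through `tf.saved_model.load(...).signatures`. It uses a toy graph in which the placeholder `x`, the variable `w`, and the output `y` are hypothetical stand-ins for the real musicnn tensors, whose names would need to be looked up in the restored graph.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def export_with_signature(export_dir):
    """Export a toy TF1-style graph as a SavedModel that includes a serving signature."""
    graph = tf.Graph()
    with graph.as_default():
        # Hypothetical stand-ins for the real musicnn input/output tensors.
        x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3], name="x")
        w = tf.compat.v1.get_variable(
            "w", shape=[3, 2], initializer=tf.compat.v1.zeros_initializer())
        y = tf.nn.sigmoid(tf.matmul(x, w), name="y")
        with tf.compat.v1.Session(graph=graph) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            # The signature maps user-facing keys to graph tensors.
            signature = tf.compat.v1.saved_model.signature_def_utils.predict_signature_def(
                inputs={"x": x}, outputs={"y": y})
            builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
            builder.add_meta_graph_and_variables(
                sess,
                [tf.saved_model.SERVING],
                signature_def_map={
                    tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature},
                strip_default_attrs=True)
            builder.save()


def predict(export_dir, batch):
    """Load the SavedModel in TF2 and run its serving signature on one batch."""
    loaded = tf.saved_model.load(export_dir)
    infer = loaded.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    # Signature functions take keyword arguments and return a dict keyed by output name.
    outputs = infer(x=tf.constant(batch, dtype=tf.float32))
    return outputs["y"].numpy()


export_dir = os.path.join(tempfile.mkdtemp(), "toy_model")
export_with_signature(export_dir)
probs = predict(export_dir, np.ones((4, 3), dtype=np.float32))
```

The restored object is not a Keras model, so it has no `predict` method; the signature function fills that role. If the real graph's output also depends on other placeholders, they would have to be added to `inputs` and fed as well.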

I can then load it with:

```python
from tensorflow import keras

model = keras.models.load_model(model_path)
```

which returns:

```
<tensorflow.python.training.tracking.tracking.AutoTrackable object at 0x7f0f6c231668>
```

This AutoTrackable object does not have a `predict` method, so the call fails when I pass it a batch of data. Is there a clean way to make this work? Thanks a lot!
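As an alternative that works even with the signature-less export from the issue, the SavedModel can be restored into a v1 `Session` with `tf.compat.v1.saved_model.loader.load`, and tensors can then be fed and fetched by name. The sketch below rebuilds a toy graph and exports it the same way as `convert()` (tags `TRAINING` and `SERVING`, no `signature_def_map`). The tensor names `x:0` and `y:0` are hypothetical. For the musicnn export, the real names would have to be found first, for example by listing `sess.graph.get_operations()`.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

TAGS = [tf.saved_model.TRAINING, tf.saved_model.SERVING]


def export_without_signature(export_dir):
    """Mimic the export in the issue: graph and variables only, no signature_def_map."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3], name="x")
        w = tf.compat.v1.get_variable(
            "w", shape=[3, 2], initializer=tf.compat.v1.zeros_initializer())
        tf.nn.sigmoid(tf.matmul(x, w), name="y")  # output tensor is "y:0"
        with tf.compat.v1.Session(graph=graph) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
            builder.add_meta_graph_and_variables(sess, TAGS, strip_default_attrs=True)
            builder.save()


def run_by_tensor_names(export_dir, batch, input_name="x:0", output_name="y:0"):
    """Restore the MetaGraph into a v1 Session and feed/fetch tensors by name."""
    graph = tf.Graph()
    with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
        # The tag set must match the one used at export time.
        tf.compat.v1.saved_model.loader.load(sess, TAGS, export_dir)
        return sess.run(output_name, feed_dict={input_name: batch})


export_dir = os.path.join(tempfile.mkdtemp(), "toy_model")
export_without_signature(export_dir)
scores = run_by_tensor_names(export_dir, np.ones((2, 3), dtype=np.float32))
```

This avoids re-exporting, but callers have to know the internal tensor names. Adding a signature at export time hides those names behind stable keys.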