keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

How to convert Tensorflow model to Keras? #8026

Closed: alexander-rakhlin closed this issue 7 years ago

alexander-rakhlin commented 7 years ago

Hello all,

This is actually not an issue but a question. Is there a tool or guide for converting a native TensorFlow model (Inception) to Keras?

mahnerak commented 7 years ago
  1. If you want to convert the network architecture, you need to rewrite the code in Keras.
  2. If you want to convert only the weights (assuming you have code for the same model), create the model with random weights (InceptionV3 is available in keras.applications), read the TensorFlow .ckpt file with tf.train.NewCheckpointReader, and then call the set_weights() method on each corresponding layer.

I think this repo may be very helpful. It has both the code for the network [1] and a script that extracts the weights and loads them [2]: https://github.com/myutwo150/keras-inception-resnet-v2
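Step 2 might be sketched roughly as follows. This is a minimal sketch under stated assumptions, not the code from the linked repo. It assumes slim-style variable names (e.g. 'InceptionV3/Conv2d_1a_3x3/weights' and '.../BatchNorm/moving_mean') and Keras layers built like keras.applications InceptionV3, i.e. Conv2D without a bias and BatchNormalization with scale=False, so each batch-norm layer holds only beta, moving_mean and moving_variance. The reader is whatever tf.train.NewCheckpointReader(path) returns; only its get_variable_to_shape_map() and get_tensor() methods are used, so TensorFlow is not imported here. The helper names and any layer-to-scope mapping are hypothetical.

```python
# Sketch: turn slim checkpoint tensors into per-layer weight lists for
# Keras layer.set_weights(). Assumes slim naming and InceptionV3-style layers.

# Assumed weight order per Keras layer type:
# Conv2D(use_bias=False) stores only the kernel;
# BatchNormalization(scale=False) stores beta, moving_mean, moving_variance.
CONV_SUFFIXES = ["weights"]
BN_SUFFIXES = ["BatchNorm/beta", "BatchNorm/moving_mean", "BatchNorm/moving_variance"]


def read_checkpoint(reader):
    """Return {variable_name: array} from a NewCheckpointReader-like object."""
    names = reader.get_variable_to_shape_map().keys()
    return {name: reader.get_tensor(name) for name in names}


def weights_for_scope(tensors, scope, suffixes):
    """Pick the tensors under `scope` in the order given by `suffixes`.

    A missing tensor raises KeyError, so a wrong mapping fails loudly.
    """
    return [tensors[scope + "/" + suffix] for suffix in suffixes]


def build_weight_lists(tensors, layer_to_scope):
    """Map each Keras layer name to the list to pass to set_weights().

    layer_to_scope: {keras_layer_name: (tf_scope, suffixes)}. This mapping
    is model-specific and has to be written by hand.
    """
    return {
        layer: weights_for_scope(tensors, scope, suffixes)
        for layer, (scope, suffixes) in layer_to_scope.items()
    }
```

With a real model, each entry of the result would then be applied with model.get_layer(name).set_weights(weights). TensorFlow and channels_last Keras both store 2D convolution kernels as (height, width, in_channels, out_channels), so no transpose should be needed for those. The classifier is likely an exception: the slim checkpoint's logits are a 1x1 convolution over 1001 classes (index 0 is a background class), so a Keras Dense layer over 1000 classes would need the kernel reshaped and that first class dropped.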

alexander-rakhlin commented 7 years ago

@mahnerak many thanks! I managed to convert the model my own way, but the script you suggest does it more systematically.

For those interested, I instantiate the model in TensorFlow:

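    # TensorFlow 1.x only: tf.contrib was removed in TensorFlow 2
    import tensorflow as tf
    from tensorflow.contrib import slim
    from tensorflow.contrib.slim.nets import inception

    # Single 299x299 RGB image with pixel values in [0, 1]
    image = tf.placeholder(tf.float32, shape=(299, 299, 3))
    # Add a batch dimension and rescale from [0, 1] to [-1, 1]
    preprocessed = tf.multiply(tf.subtract(tf.expand_dims(image, 0), 0.5), 2.0)

    # Build the Inception v3 graph in inference mode
    with slim.arg_scope(inception.inception_v3_arg_scope()):
        logits, end_points = inception.inception_v3(
            preprocessed, num_classes=1001, is_training=False)

    # Session creator that restores the model variables from the checkpoint
    saver = tf.train.Saver(slim.get_model_variables())
    session_creator = tf.train.ChiefSessionCreator(
        scaffold=tf.train.Scaffold(saver=saver),
        checkpoint_filename_with_path="inception_v3.ckpt",
        master="")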
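The preprocessing line matters once the converted weights are used from Keras: the snippet expects pixel values in [0, 1] and maps them to [-1, 1]. Keras's InceptionV3 preprocess_input maps [0, 255] to the same [-1, 1] range, so the two should agree once inputs are divided by 255. The check below re-implements both formulas in plain Python rather than importing either library; the function names are made up for illustration.

```python
def tf_snippet_preprocess(x):
    """Formula from the snippet above: input in [0, 1], output in [-1, 1]."""
    return (x - 0.5) * 2.0


def keras_style_preprocess(x):
    """Re-implementation of Keras InceptionV3 scaling: input in [0, 255], output in [-1, 1]."""
    return x / 127.5 - 1.0


# (p / 255 - 0.5) * 2 == p / 127.5 - 1 for every 8-bit pixel value p
for p in (0, 128, 255):
    print(p, tf_snippet_preprocess(p / 255.0), keras_style_preprocess(float(p)))
```

Algebraically (p/255 - 0.5) * 2 = p/127.5 - 1, so feeding 0..255 images to the Keras function should match feeding 0..1 images to the TensorFlow graph, up to floating-point rounding.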

Collect the model variables under the 'InceptionV3/' scope:

    restore_vars = [
        var for var in slim.get_model_variables()
        if var.name.startswith('InceptionV3/')
    ]

Then extract their values:


    with tf.train.MonitoredSession(session_creator=session_creator) as sess:
        tf_layers = {}
        mode = ""
        for v in restore_vars:
            # Fetch the restored value of this variable as a NumPy array
            w = sess.run(v)
....
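The rest of that loop is elided. One possible way to organise the fetched values, a hypothetical sketch and not the author's code, is to group them by layer scope, keyed by the leaf name after that scope. It assumes slim naming, where batch-norm variables live under '&lt;layer&gt;/BatchNorm/' and other variables end in '/weights' or '/biases', and that names carry TensorFlow's ':0' output suffix. The function names are hypothetical.

```python
def split_var_name(name):
    """Split a slim variable name into (layer_scope, leaf).

    "InceptionV3/Conv2d_1a_3x3/BatchNorm/beta:0"
        -> ("InceptionV3/Conv2d_1a_3x3", "BatchNorm/beta")
    """
    base = name.split(":")[0]  # drop the output index, e.g. ":0"
    if "/BatchNorm/" in base:
        scope, leaf = base.split("/BatchNorm/", 1)
        return scope, "BatchNorm/" + leaf
    scope, leaf = base.rsplit("/", 1)  # e.g. ".../weights" or ".../biases"
    return scope, leaf


def group_by_layer(named_values):
    """Build {layer_scope: {leaf: value}} from (variable_name, value) pairs."""
    tf_layers = {}
    for name, value in named_values:
        scope, leaf = split_var_name(name)
        tf_layers.setdefault(scope, {})[leaf] = value
    return tf_layers
```

Inside the session above, the pairs could come from zip([v.name for v in restore_vars], sess.run(restore_vars)); passing a list to sess.run fetches all values in one call and returns them in the same order. Splitting at '/BatchNorm/' rather than at a fixed depth keeps nested scopes such as 'InceptionV3/Mixed_5b/Branch_0/Conv2d_0a_1x1' intact.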
mrgloom commented 5 years ago

Is it possible to get the weights from a frozen .pb graph?