omoindrot / tensorflow-triplet-loss

Implementation of triplet loss in TensorFlow
https://omoindrot.github.io/triplet-loss
MIT License

Saving weights of model and calculation of embeddings #48

Closed: amanattrish closed this issue 4 years ago

amanattrish commented 4 years ago

I was following your article "Triplet Loss and Online Triplet Mining in TensorFlow". I wanted to save the model weights defined in the embedImages(Images) method and later use them to compute the embedding of a given input image. I tried doing this in a new session, as shown below:

    im = cv2.imread(img_path)
    img = np.reshape(im, (1, 374, 388, 3))
    with tf.Session() as session:
        result = session.run(embedded_images, feed_dict={Images: img})
        print(result)

But I didn't get any useful results. Could you share an idea of how to do this? @omoindrot
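A likely reason the snippet above gives nothing useful is that a brand-new `tf.Session` starts with every variable uninitialized: the trained values live only in the training session (or in a checkpoint), and nothing in the snippet loads them. The sketch below reproduces this with a tiny stand-in model. The 4x4x3 input shape, the variable `w` and the single linear layer are assumptions for illustration, not the article's network, and the code uses `tf.compat.v1` so it also runs under TensorFlow 2.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def fresh_session_fails():
    """Return True if running the embedding in a new, un-restored session raises."""
    with tf.Graph().as_default():
        # Hypothetical stand-in for embedImages(Images): one linear layer.
        images = tf1.placeholder(tf.float32, [None, 4, 4, 3], name="Images")
        w = tf1.get_variable("w", [4 * 4 * 3, 2])
        embedded_images = tf.matmul(tf.reshape(images, [-1, 4 * 4 * 3]), w)

        batch = np.zeros((1, 4, 4, 3), dtype=np.float32)
        with tf1.Session() as session:  # no initializer and no saver.restore
            try:
                session.run(embedded_images, feed_dict={images: batch})
            except tf.errors.FailedPreconditionError:
                return True  # "w" was never initialized or restored in this session
    return False


print(fresh_session_fails())
```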

amanattrish commented 4 years ago

For anyone stuck on this problem, here is what I tried:

    # Insert this inside the training session (sess), after training
    saver = tf.train.Saver()
    save_path = saver.save(sess, "path_to_save_model/model_name.ckpt")

Then restore the saved model in a new session to calculate the embedding:

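    img = cv2.imread('img_path.png')
    img = np.expand_dims(img, axis=0)  # add the batch dimension: (1, height, width, 3)
    # Define Images and embedded_images again (same graph as in training) before creating the saver
    saver = tf.train.Saver()
    with tf.Session() as new_sess:
        saver.restore(new_sess, "path_to_save_model/model_name.ckpt")  # load the trained weights
        result = new_sess.run(embedded_images, feed_dict={Images: img})
        print(result)  # result is the desired embedding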
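A minimal end-to-end sketch of this save/restore round trip, using `tf.compat.v1` so it also runs under TensorFlow 2. The inference side rebuilds the same graph in a fresh `tf.Graph` and calls `saver.restore` instead of running an initializer; the saver matches variables by name, so both graphs must create them identically. The toy linear "embedding" model, the helper names `build_graph` and `save_then_restore`, and the temporary checkpoint path are hypothetical, chosen only for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def build_graph():
    """Hypothetical stand-in for embedImages(Images); variable names must match at restore."""
    images = tf1.placeholder(tf.float32, [None, 4, 4, 3], name="Images")
    w = tf1.get_variable("w", [4 * 4 * 3, 2])
    embedded_images = tf.matmul(tf.reshape(images, [-1, 4 * 4 * 3]), w)
    return images, embedded_images


def save_then_restore(ckpt_path, batch):
    # Training process: initialize (training steps would go here), then save.
    with tf.Graph().as_default():
        images, embedded_images = build_graph()
        saver = tf1.train.Saver()
        with tf1.Session() as sess:
            sess.run(tf1.global_variables_initializer())
            before = sess.run(embedded_images, feed_dict={images: batch})
            saver.save(sess, ckpt_path)

    # Inference process: fresh graph, same definitions, restore instead of init.
    with tf.Graph().as_default():
        images, embedded_images = build_graph()
        saver = tf1.train.Saver()
        with tf1.Session() as new_sess:
            saver.restore(new_sess, ckpt_path)
            after = new_sess.run(embedded_images, feed_dict={images: batch})
    return before, after


with tempfile.TemporaryDirectory() as tmp:
    batch = np.random.rand(3, 4, 4, 3).astype(np.float32)
    before, after = save_then_restore(os.path.join(tmp, "model_name.ckpt"), batch)
    print(np.allclose(before, after))
```

Because the restored weights are the saved ones, the embeddings computed before saving and after restoring should agree.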

If anyone has a better solution, please share!
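One option that avoids redefining `embedded_images` by hand is to restore the graph itself. By default `Saver.save` also writes a `.meta` file next to the checkpoint, and `tf.compat.v1.train.import_meta_graph` rebuilds the graph from it and returns a saver for loading the weights. The sketch assumes the input and output tensors were given the names `Images` and `embedded_images` when the graph was built (the `tf.identity` call does this); those names, the toy model and the helpers `train_and_save` and `embed` are illustrative assumptions.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def train_and_save(ckpt_path):
    """Toy training graph; the op names are what the inference side looks up."""
    with tf.Graph().as_default():
        images = tf1.placeholder(tf.float32, [None, 4, 4, 3], name="Images")
        w = tf1.get_variable("w", [4 * 4 * 3, 2])
        embeddings = tf.matmul(tf.reshape(images, [-1, 4 * 4 * 3]), w)
        tf.identity(embeddings, name="embedded_images")  # give the output a stable name
        saver = tf1.train.Saver()
        with tf1.Session() as sess:
            sess.run(tf1.global_variables_initializer())
            saver.save(sess, ckpt_path)  # also writes ckpt_path + ".meta"


def embed(ckpt_path, batch):
    """Rebuild the graph from the .meta file, restore weights, and run the embedding."""
    graph = tf.Graph()
    with graph.as_default():
        saver = tf1.train.import_meta_graph(ckpt_path + ".meta")
        images = graph.get_tensor_by_name("Images:0")
        embedded_images = graph.get_tensor_by_name("embedded_images:0")
        with tf1.Session() as sess:
            saver.restore(sess, ckpt_path)
            return sess.run(embedded_images, feed_dict={images: batch})


with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "model_name.ckpt")
    train_and_save(ckpt)
    result = embed(ckpt, np.ones((1, 4, 4, 3), dtype=np.float32))
    print(result.shape)
```

The trade-off is that inference now depends on tensor names rather than on Python code, so those names need to stay stable across training runs.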