tks10 / segmentation_unet

Semantic segmentation using U-NET

saving and restoring the trained model #4

Open: katiejx opened this issue 3 years ago

katiejx commented 3 years ago

Hi @tks10 and @PINTO0309! I hope you're safe and healthy.

Thank you very much for sharing your code. It was extremely helpful for understanding U-Net and implementing it for segmentation on my own set of images. However, I am stuck on one point and would be grateful for your help getting past it.

I see that repoter.py defines a save_model() function, but it is never called in main.py. I am trying to save my trained model and then restore it in a separate script so that I can run it on a new image. When I use tf.train.Checkpoint, I get the following error:

*** ValueError: `Checkpoint` was expecting a trackable object (an object derived from `TrackableBase`), got <util.model.Model object at 0x000002BE7149B910>. If you believe this object should be trackable (i.e. it is part of the TensorFlow Python API and manages state), please open an issue.
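For context on the error: tf.train.Checkpoint only accepts trackable objects (tf.Variable, tf.Module, Keras layers and models, and so on). The message indicates that util.model.Model does not derive from TrackableBase; it appears to be a plain Python class that builds a TF1-style graph, so Checkpoint cannot see its variables. In session-based code, tf.compat.v1.train.Saver is the usual way to save. Below is a minimal sketch, assuming TensorFlow 2.x run through tensorflow.compat.v1 in graph mode; build_stand_in_model, save_trained_model, and the tensor names input_image and output_mask are hypothetical stand-ins, not code from this repository. The main point is to give the input placeholder and the output tensor explicit names before saving, so they can be looked up after restoring.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # placeholders and sessions require graph mode under TF 2.x


def build_stand_in_model(num_classes=6):
    """Hypothetical stand-in for util.model.Model: a single 1x1 conv on a 128x128 RGB input."""
    images = tf.placeholder(tf.float32, [None, 128, 128, 3], name="input_image")
    kernel = tf.get_variable("kernel", [1, 1, 3, num_classes])
    bias = tf.get_variable("bias", [num_classes], initializer=tf.zeros_initializer())
    logits = tf.nn.conv2d(images, kernel, strides=[1, 1, 1, 1], padding="SAME") + bias
    # Give the tensor needed at inference time an explicit, stable name.
    output = tf.identity(logits, name="output_mask")
    return images, output


def save_trained_model(sess, save_dir):
    """Write the weights (.index/.data) and the graph definition (.meta) with a v1 Saver."""
    saver = tf.train.Saver()  # collects all global variables in the current graph
    # Returns the checkpoint prefix, here "<save_dir>/unet-0".
    return saver.save(sess, os.path.join(save_dir, "unet"), global_step=0)


save_dir = tempfile.mkdtemp()
with tf.Graph().as_default():
    images, output = build_stand_in_model()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # ... the training loop from main.py would run here ...
        ckpt_prefix = save_trained_model(sess, save_dir)

print(ckpt_prefix)
```

In the real project, the equivalent would be to create the Saver after model_unet has been built and call saver.save(...) once training finishes.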

When I use tf.train.Saver instead, the model is saved, but after restoring it (with tf.train.import_meta_graph) I cannot reproduce the output. In particular, it is not clear to me which variable or operation of the restored graph I should pass to sess.run() with a feed_dict in order to get model_unet.output. When I try 'conv2d_18/bias/Adam_1:0', which is the last variable of the model, I get a 6x1 array instead of a 128x128 array.
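A likely explanation for the 6x1 result: in TF1, the Adam optimizer adds two slot variables per weight, named <variable>/Adam (first moment) and <variable>/Adam_1 (second moment). So 'conv2d_18/bias/Adam_1:0' is an optimizer accumulator for the last layer's bias, with the same shape as that bias (one value per output channel), not the network's output. Presumably the model has 6 output classes. What sess.run() needs is the output tensor itself, fed through the input placeholder. The sketch below assumes TensorFlow 2.x via tensorflow.compat.v1; train_and_save, predict_from_checkpoint, and the names input_image, output_mask, and conv2d_18 are hypothetical and only mimic the setup described above. It saves a tiny Adam-trained graph, restores it with import_meta_graph, fetches the output by name, and checks the stored shape of the Adam slot.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # placeholders and sessions require graph mode under TF 2.x


def train_and_save(save_dir, num_classes=6):
    """Hypothetical stand-in for main.py: tiny graph + Adam, one step, then a checkpoint."""
    with tf.Graph().as_default():
        images = tf.placeholder(tf.float32, [None, 128, 128, 3], name="input_image")
        with tf.variable_scope("conv2d_18"):  # mimics the last layer's name from the issue
            kernel = tf.get_variable("kernel", [1, 1, 3, num_classes])
            bias = tf.get_variable("bias", [num_classes], initializer=tf.zeros_initializer())
        logits = tf.nn.conv2d(images, kernel, strides=[1, 1, 1, 1], padding="SAME") + bias
        tf.identity(logits, name="output_mask")  # the tensor to fetch at inference time
        loss = tf.reduce_mean(tf.square(logits - 1.0))
        # Adam creates slots such as conv2d_18/bias/Adam and conv2d_18/bias/Adam_1.
        train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(train_op, feed_dict={images: np.ones((1, 128, 128, 3), np.float32)})
            return tf.train.Saver().save(sess, os.path.join(save_dir, "unet"))


def predict_from_checkpoint(ckpt_prefix, batch):
    """Rebuild the graph from the .meta file, restore weights, and run the named output."""
    with tf.Graph().as_default() as graph:
        saver = tf.train.import_meta_graph(ckpt_prefix + ".meta")
        with tf.Session(graph=graph) as sess:
            saver.restore(sess, ckpt_prefix)
            images = graph.get_tensor_by_name("input_image:0")
            output = graph.get_tensor_by_name("output_mask:0")
            return sess.run(output, feed_dict={images: batch})


ckpt_prefix = train_and_save(tempfile.mkdtemp())
batch = np.random.rand(1, 128, 128, 3).astype(np.float32)
prediction = predict_from_checkpoint(ckpt_prefix, batch)
mask = prediction[0].argmax(axis=-1)  # per-pixel class index, shape (128, 128)

# Shapes stored in the checkpoint: the Adam slot has the bias's shape, one value per class.
var_shapes = dict(tf.train.list_variables(ckpt_prefix))
print(prediction.shape, mask.shape, var_shapes["conv2d_18/bias/Adam_1"])
```

With the real U-Net, the two get_tensor_by_name calls would have to use whatever names the input placeholder and final output actually carry in the saved graph. If they were not named explicitly, printing [op.name for op in graph.get_operations()] after import_meta_graph is one way to find them. The argmax step assumes the output has one channel per class.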

I would be more than grateful if you could please help me with this.