MarvinTeichmann / tensorflow-fcn

An Implementation of Fully Convolutional Networks in Tensorflow.
MIT License
1.1k stars, 433 forks

Can we store the fine-tuned weights using tf.train.Saver()? #21

robot010 closed this issue 7 years ago

robot010 commented 7 years ago

Hi, I am fine-tuning your FCN on my own dataset. After tuning, I was fetching tensors such as 'conv1_1/filter:0' and 'conv1_1/biases:0' and calling sess.run(filters) to extract the trained weights and save them to an .npy file. But since the network has lots of filters, listing all of them by hand would be messy.
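For reference, the manual .npy export described above can be done without listing names by hand. A minimal sketch, assuming a TF1-style graph (accessed here through tensorflow.compat.v1): fetch every variable in tf.global_variables() with a single sess.run call and save a {name: array} dict. The toy conv1_1 layer, its shapes, and the helper names are hypothetical.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style graph API, also shipped with TF 2.x


def dump_variables_to_npy(sess, path):
    """Fetch all global variables in one sess.run call and save them to .npy
    as a {variable name: ndarray} dict, so no names are listed by hand."""
    variables = tf.global_variables()
    values = sess.run(variables)
    weights = {v.name: value for v, value in zip(variables, values)}
    np.save(path, weights)  # the dict is pickled into a 0-d object array
    return weights


def load_npy_weights(path):
    # allow_pickle=True is needed because the file stores a Python dict.
    return np.load(path, allow_pickle=True).item()


def build_toy_layer():
    # Hypothetical stand-in for one FCN layer; shapes are illustrative only.
    with tf.variable_scope("conv1_1"):
        tf.get_variable("filter", shape=[3, 3, 3, 8])
        tf.get_variable("biases", shape=[8], initializer=tf.zeros_initializer())


if __name__ == "__main__":
    with tf.Graph().as_default():
        build_toy_layer()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            path = os.path.join(tempfile.mkdtemp(), "finetuned.npy")
            dump_variables_to_npy(sess, path)
    print(sorted(load_npy_weights(path)))
```

The keys are full tensor names, so the file can be mapped back onto the graph later by name.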

So I was wondering whether we can do this with tf.train.Saver, which lets us store the trained Variables quickly and cleanly. But then I noticed that there is no "Variable" in the fcn8s_vgg file. So I am thinking that with your original fcn8s_vgg.py file we can't use tf.train.Saver(), right?

I am new to TensorFlow; please let me know what you think. Thanks.

MarvinTeichmann commented 7 years ago

No, you can use this code and store the variables with the TensorFlow Saver. This is in fact done in the KittiSeg repository. Take a closer look at demo.py, where a trained KittiSeg model, which is based on FCN8, is loaded.
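To make this concrete, a minimal sketch, assuming the layers are created with tf.get_variable (the tensor names 'conv1_1/filter:0' and 'conv1_1/biases:0' from the question fit that pattern). tf.get_variable returns a tf.Variable and registers it in the global-variables collection, which is what tf.train.Saver() saves by default, so no explicit tf.Variable(...) call is needed in the model file. The toy layer, shapes, and helper names are hypothetical; tensorflow.compat.v1 is used so the sketch also runs on TF 2.x.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # TF1-style graph API, also shipped with TF 2.x


def build_toy_layer():
    # Hypothetical stand-in for one FCN conv layer. tf.get_variable creates a
    # real tf.Variable even though "tf.Variable" never appears in the source.
    with tf.variable_scope("conv1_1"):
        tf.get_variable("filter", shape=[3, 3, 3, 8])
        tf.get_variable("biases", shape=[8], initializer=tf.zeros_initializer())


def save_finetuned(ckpt_prefix):
    """Build the toy graph, initialize it, and save every global variable."""
    with tf.Graph().as_default():
        build_toy_layer()
        saver = tf.train.Saver()  # defaults to all saveable global variables
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # ... fine-tuning steps would run here ...
            return saver.save(sess, ckpt_prefix)  # returns the checkpoint prefix


if __name__ == "__main__":
    prefix = save_finetuned(os.path.join(tempfile.mkdtemp(), "model.ckpt"))
    # (name, shape) pairs stored in the checkpoint, names without ":0".
    print(tf.train.list_variables(prefix))
```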

I am using .npy to store the original VGG weights, as the .ckpt format is very inflexible: you cannot easily add, remove, or rename variables after storing the graph to .ckpt. That flexibility, however, is very useful when using this code as part of more advanced models. For your purpose, saving to .ckpt should be fine.
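To illustrate the flexibility argument, a sketch assuming the weights are kept as a pickled dict mapping layer names to [filter, biases] arrays (similar to the layout commonly used for VGG .npy weight files). Adding, dropping, or renaming a layer is then plain dictionary editing before the graph is built. The layer names encoder/conv1_1 and score_new and the helper functions are hypothetical.

```python
import os
import tempfile

import numpy as np


def edit_weights(weights, drop=(), rename=None, add=None):
    """Return a new {layer name: [filter, biases]} dict with layers dropped,
    renamed, or added; the input dict is left unchanged."""
    rename = rename or {}
    edited = {rename.get(name, name): params
              for name, params in weights.items() if name not in drop}
    edited.update(add or {})
    return edited


def save_npy(weights, path):
    np.save(path, weights)  # the dict is pickled into a 0-d object array


def load_npy(path):
    # allow_pickle=True is needed because the file stores a Python dict.
    return np.load(path, allow_pickle=True).item()


if __name__ == "__main__":
    # Hypothetical toy weights in a layer name -> [filter, biases] layout.
    original = {
        "conv1_1": [np.ones((3, 3, 3, 4), np.float32), np.zeros(4, np.float32)],
        "fc8": [np.ones((1, 1, 16, 10), np.float32), np.zeros(10, np.float32)],
    }
    new_layer = [np.zeros((1, 1, 16, 2), np.float32), np.zeros(2, np.float32)]
    path = os.path.join(tempfile.mkdtemp(), "edited.npy")
    save_npy(edit_weights(original, drop=("fc8",),
                          rename={"conv1_1": "encoder/conv1_1"},
                          add={"score_new": new_layer}), path)
    print(sorted(load_npy(path)))
```

With a .ckpt file, the same edits would instead have to be handled when the variables are matched by name at restore time.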

robot010 commented 7 years ago

@MarvinTeichmann Thank you for your quick reply! I checked KittiSeg, but all I could find was code like if 'TV_SAVE' in os.environ and os.environ['TV_SAVE']: tf.app.flags.DEFINE_boolean( 'save', True, ..., which looks like magic to me. I did find a tf.train.Saver() in eval.py in the TensorVision submodule. Do I need to install TensorVision as well? I am just curious why we can use Saver here, since there is no Variable in the FCN8s file. Thank you

MarvinTeichmann commented 7 years ago

Checkpoints are loaded in core.load_weights(logdir, sess, saver), as used in demo.py:156. Checkpoints are saved in tensorvision/train.py:323. Hope this helps!
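A minimal sketch of the save/restore cycle described here, assuming periodic saves with a global_step and a restore from the newest checkpoint in a log directory. This is not the actual code of core.load_weights or tensorvision/train.py; the one-variable toy model and the helper names are hypothetical.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # TF1-style graph API, also shipped with TF 2.x


def build_graph():
    # Hypothetical one-variable model; the increment stands in for a train step.
    with tf.variable_scope("conv1_1"):
        biases = tf.get_variable("biases", shape=[4],
                                 initializer=tf.zeros_initializer())
    train_op = biases.assign_add(tf.ones([4]))
    return biases, train_op


def train_and_save(logdir, steps=3):
    with tf.Graph().as_default():
        _, train_op = build_graph()
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for step in range(1, steps + 1):
                sess.run(train_op)
                # Writes model.ckpt-<step>.* and updates logdir/checkpoint.
                saver.save(sess, os.path.join(logdir, "model.ckpt"),
                           global_step=step)


def restore_latest(logdir):
    with tf.Graph().as_default():
        biases, _ = build_graph()
        saver = tf.train.Saver()
        with tf.Session() as sess:
            ckpt = tf.train.latest_checkpoint(logdir)  # newest model.ckpt-<step>
            saver.restore(sess, ckpt)  # no initializer run needed after restore
            return ckpt, sess.run(biases)


if __name__ == "__main__":
    logdir = tempfile.mkdtemp()
    train_and_save(logdir)
    print(restore_latest(logdir))
```

Because restore matches variables by name, the graph rebuilt for inference must create the same variable names that were present at save time.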