taki0112 / UGATIT

Official TensorFlow implementation of U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation (ICLR 2020)
MIT License

Name of output_node_names for .pb #70

Closed: DiMiTriFrog closed this issue 4 years ago

DiMiTriFrog commented 4 years ago

I'm trying to convert the checkpoint to a TensorFlow .pb graph.

python3 freeze_graph.py --input_meta_graph=checkpoint/UGATIT_selfie2anime_lsgan_4resblock_6dis_1_1_10_10_1000_sn_smoothing/UGATIT.model-1000000.meta --input_checkpoint=checkpoint/UGATIT_selfie2anime_lsgan_4resblock_6dis_1_1_10_10_1000_sn_smoothing --output_graph=selfie2anime.pb --input_binary=True --output_node_names="???"

What do I have to write in output_node_names?

Thank you!

FantasyJXF commented 4 years ago

@DiMiTriFrog generator_B/Tanh might work. The resulting .pb model is about a.15GB, but inference with it isn't working yet.
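Since the suggested node is generator_B/Tanh, one way to find candidate output names is to list the Tanh ops under the generator's scope. The sketch below is a minimal, hypothetical helper (find_output_candidates is not part of UGATIT or TensorFlow). It works on plain (name, op) pairs so it runs without TensorFlow; with TF 1.x the pairs could come from [(n.name, n.op) for n in tf.get_default_graph().as_graph_def().node] after tf.train.import_meta_graph(...). Apart from generator_B/Tanh, the node names in the demo are made up. When a variable scope is re-entered for reuse, TF 1.x uniquifies the name scope (generator_B_1, generator_B_2, ...), so several matches can show up. Each one is a separate call of the same generator, and the one to freeze is the call whose input you actually plan to feed.

```python
import re


def find_output_candidates(nodes, op_type="Tanh", scope="generator_B"):
    """Return names of `op_type` ops under `scope` or a uniquified copy of it.

    `nodes` is an iterable of (name, op) pairs. TF 1.x renames re-entered
    name scopes to scope_1, scope_2, ..., so those are matched too.
    """
    pattern = re.compile(r"^%s(_\d+)?/" % re.escape(scope))
    return [name for name, op in nodes if op == op_type and pattern.match(name)]


if __name__ == "__main__":
    # Hypothetical node list; only generator_B/Tanh comes from this thread.
    sample = [
        ("generator_B/conv/Conv2D", "Conv2D"),
        ("generator_B/Tanh", "Tanh"),
        ("generator_B_1/Tanh", "Tanh"),
        ("generator_A/Tanh", "Tanh"),
    ]
    print(find_output_candidates(sample))  # ['generator_B/Tanh', 'generator_B_1/Tanh']
```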

FantasyJXF commented 4 years ago

It's strange: when I look inside the .pb file, it contains the training file paths, like this:

            string_val: "./dataset/selfie2anime/trainB/3214.png"
            string_val: "./dataset/selfie2anime/trainB/0527.png"
            string_val: "./dataset/selfie2anime/trainB/3057.png"
            string_val: "./dataset/selfie2anime/trainB/2055.png"
            string_val: "./dataset/selfie2anime/trainB/0388.png"
            string_val: "./dataset/selfie2anime/trainB/1329.png"
            string_val: "./dataset/selfie2anime/trainB/2474.png"
            string_val: "./dataset/selfie2anime/trainB/2638.png"
            string_val: "./dataset/selfie2anime/trainB/1964.png"
            string_val: "./dataset/selfie2anime/trainB/0124.png"
            string_val: "./dataset/selfie2anime/trainB/1755.png"
          }
        }
      }
    }
    node_def {
      name: "TensorSliceDataset"
      op: "TensorSliceDataset"
      input: "TensorSliceDataset/tensors_1/component_0:output:0"
      attr {
        key: "Toutput_types"
        value {
          list {
            type: DT_STRING
          }
        }
      }
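The input: line in that dump uses the node:output_arg:index form, which is how nodes refer to each other inside a function body (FunctionDef.node_def). Top-level GraphDef nodes use node, node:index or ^node instead. This suggests the file list sits in the graph's function library, which tf.data can populate when an iterator (for example one from make_one_shot_iterator()) captures the input pipeline. As far as I can tell, freezing copies that library as-is rather than pruning it the way it prunes unused top-level nodes. This is an interpretation of the dump, not something confirmed in this thread. The small parser below, a hypothetical helper, just makes the two reference forms explicit.

```python
from typing import NamedTuple, Optional


class InputRef(NamedTuple):
    node: str                  # producing node, or a function argument if bare
    output_arg: Optional[str]  # set only for the FunctionDef-body form
    index: int                 # output index
    control: bool              # True for "^name" control dependencies


def parse_input_ref(ref: str) -> InputRef:
    """Parse one `input:` string from a GraphDef or FunctionDef text dump."""
    control = ref.startswith("^")
    parts = ref[1:].split(":") if control else ref.split(":")
    if len(parts) == 1:    # "name": output 0 (GraphDef) or a function argument
        return InputRef(parts[0], None, 0, control)
    if len(parts) == 2:    # "name:1": GraphDef form with an output index
        return InputRef(parts[0], None, int(parts[1]), control)
    if len(parts) == 3:    # "name:output_arg:1": form used inside function bodies
        return InputRef(parts[0], parts[1], int(parts[2]), control)
    raise ValueError("unrecognised input reference: %r" % ref)


def is_function_body_ref(ref: str) -> bool:
    return parse_input_ref(ref).output_arg is not None


if __name__ == "__main__":
    ref = "TensorSliceDataset/tensors_1/component_0:output:0"  # from the dump above
    print(parse_input_ref(ref))
    print(is_function_body_ref(ref), is_function_body_ref("generator_B/Tanh"))  # True False
```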

DiMiTriFrog commented 4 years ago

generator_B/Tanh

Thanks for your response!

I get an error when I run the full command:

python3 freeze_graph.py --input_meta_graph=checkpoint/UGATIT_selfie2anime_lsgan_4resblock_6dis_1_1_10_10_1000_sn_smoothing/UGATIT.model-1000000.meta --input_checkpoint=checkpoint/UGATIT_selfie2anime_lsgan_4resblock_6dis_1_1_10_10_1000_sn_smoothing --output_graph=selfie2anime.pb --input_binary=True --output_node_names="generator_B/Tanh"

->

Traceback (most recent call last):
  File "/Users/user/Desktop/freeze_graph.py", line 491, in <module>
    run_main()
  File "/Users/user/Desktop/freeze_graph.py", line 487, in run_main
    app.run(main=my_main, argv=[sys.argv[0]] + unparsed)
  File "/Users/user/anaconda3/lib/python3.7/site-packages/tensorflow/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/Users/user/anaconda3/lib/python3.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/Users/user/anaconda3/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/Users/user/Desktop/freeze_graph.py", line 486, in <lambda>
    my_main = lambda unused_args: main(unused_args, flags)
  File "/Users/user/Desktop/freeze_graph.py", line 378, in main
    flags.saved_model_tags, checkpoint_version)
  File "/Users/user/Desktop/freeze_graph.py", line 361, in freeze_graph
    checkpoint_version=checkpoint_version)
  File "/Users/user/Desktop/freeze_graph.py", line 154, in freeze_graph_with_def_protos
    input_meta_graph_def, clear_devices=True)
  File "/Users/user/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 1449, in import_meta_graph
    **kwargs)[0]
  File "/Users/user/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 1473, in _import_meta_graph_with_return_elements
    **kwargs))
  File "/Users/user/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/meta_graph.py", line 857, in import_scoped_meta_graph_with_return_elements
    return_elements=return_elements)
  File "/Users/user/anaconda3/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/Users/user/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/importer.py", line 400, in import_graph_def
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)
  File "/Users/user/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/importer.py", line 160, in _RemoveDefaultAttrs
    op_def = op_dict[node.op]
KeyError: 'IteratorGetDevice'

I'd like to create a .pb file to measure how much faster image transformation becomes, and to find out whether this is viable for serving through an API.
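Here is a minimal timing harness for that comparison, offered as a sketch. Both benchmark and infer are hypothetical. infer stands for a callable you would write around sess.run(...) for each setup (checkpoint-restored graph vs. frozen .pb), and the input and output tensors it uses depend on your graph. A few untimed warm-up calls come first because the first sess.run calls usually include one-time setup. Per-call mean and approximate 95th-percentile latency say more about API serving than a single run does.

```python
import statistics
import time


def benchmark(infer, inputs, warmup=3, runs=20):
    """Return (mean, approx. p95) seconds per `infer(x)` call over `runs` passes of `inputs`."""
    if not inputs:
        raise ValueError("need at least one input")
    for _ in range(warmup):          # untimed: absorbs one-time setup cost
        infer(inputs[0])
    timings = []
    for _ in range(runs):
        for x in inputs:
            start = time.perf_counter()
            infer(x)
            timings.append(time.perf_counter() - start)
    timings.sort()
    # Nearest-rank style pick, clamped to the last sample.
    p95 = timings[min(len(timings) - 1, int(0.95 * len(timings)))]
    return statistics.mean(timings), p95


if __name__ == "__main__":
    # Stand-in workload; replace with e.g. lambda img: sess.run(out, {inp: img}),
    # where `out` and `inp` are hypothetical output/input tensors of your graph.
    mean, p95 = benchmark(lambda n: sum(range(n)), [10_000, 20_000])
    print("mean %.6fs  p95 %.6fs" % (mean, p95))
```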

FantasyJXF commented 4 years ago
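import os
import tensorflow as tf
from tensorflow import graph_util
import argparse
# Presumably here so that contrib ops (e.g. resampler) are registered
# before import_meta_graph parses the saved graph.
import tensorflow.contrib
tf.contrib.resampler

parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', required=True)
parser.add_argument('--output', dest='output', required=True)

args = parser.parse_args()

# checkpoint_dir must directly contain the .meta file written by the Saver.
meta_graph = [meta for meta in os.listdir(args.checkpoint_dir) if '.meta' in meta]
assert (len(meta_graph) > 0)
print('meta_graph is ', meta_graph)

sess = tf.Session()
# Rebuild the graph from the .meta file (dropping device placements),
# then load the most recent weights listed in the directory's "checkpoint" file.
saver = tf.train.import_meta_graph(os.path.join(args.checkpoint_dir, meta_graph[0]), clear_devices=True)
saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint_dir))
graph = tf.get_default_graph()

input_graph_def = graph.as_graph_def()

# Keep only the nodes needed to compute the output node(s) and
# replace every variable with a constant holding its current value.
output_node_names = 'generator_B/Tanh'
output_graph_def = graph_util.convert_variables_to_constants(sess, input_graph_def, output_node_names.split(","))

# Serialize the frozen GraphDef to a binary .pb file.
with tf.gfile.GFile(args.output, "wb") as f:
    f.write(output_graph_def.SerializeToString())
sess.close()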

That's the code I used to convert the .meta graph to a .pb with TensorFlow 1.8.0. I visualized the model with Netron and it looks fine.

But I can't run inference with the .pb file; it fails with the following error:

  File "/anaconda3/envs/keras/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 900, in run
    run_metadata_ptr)
  File "/anaconda3/envs/keras/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1088, in _run
    subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
  File "/anaconda3/envs/keras/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py", line 132, in as_numpy_dtype
    return _TF_TO_NP[self._type_enum]
KeyError: 20

I hope someone can help solve this.
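For anyone decoding KeyError: 20: the failing lookup is _TF_TO_NP[self._type_enum], reached from subfeed_t.dtype.as_numpy_dtype in session.py. In other words, it fails while TensorFlow converts the dtype of a feed_dict key to NumPy. In tensorflow/core/framework/types.proto, enum value 20 is DT_RESOURCE, the handle type used for things like resource variables and dataset iterators. So one of the fed tensors appears to be a handle rather than an image tensor. A reasonable next step is to check which tensor is used as the feed key, for example whether it comes from the dataset iterator instead of an image input. The lookup below is a hypothetical helper that copies a few entries of that enum.

```python
# A few values of TensorFlow's DataType enum (tensorflow/core/framework/types.proto).
TF_DATATYPE = {
    1: "DT_FLOAT",
    2: "DT_DOUBLE",
    3: "DT_INT32",
    4: "DT_UINT8",
    7: "DT_STRING",
    9: "DT_INT64",
    10: "DT_BOOL",
    19: "DT_HALF",
    20: "DT_RESOURCE",
    21: "DT_VARIANT",
}


def dtype_name(enum_value):
    """Translate the number in a 'KeyError: <n>' from dtypes.py into a DataType name."""
    return TF_DATATYPE.get(enum_value, "unknown: look it up in types.proto")


if __name__ == "__main__":
    print(dtype_name(20))  # DT_RESOURCE
```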

DiMiTriFrog commented 4 years ago

I have obtained the .pb file thanks to your code!

Thank you very much!

tankfly2014 commented 4 years ago

I have obtained the .pb file thanks to your code!

Thank you very much!

Where's your code? Please make it public.

rtolps commented 3 years ago

Hi! I made a script from the code above (posted by FantasyJXF) and ran it, but I'm getting this error on line 16:

assert (len(meta_graph) > 0)
AssertionError

What could be causing this?
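That assertion fails when os.listdir(args.checkpoint_dir) finds no file name containing .meta. In other words, --checkpoint_dir does not point directly at the folder holding UGATIT.model-*.meta. One likely cause is passing the parent checkpoint/ directory instead of the checkpoint/UGATIT_selfie2anime_lsgan_4resblock_6dis_1_1_10_10_1000_sn_smoothing/ subfolder that appears in the commands earlier in this thread. This is an assumption, since the command used isn't shown. Below is a minimal sketch of a more informative lookup. find_meta_graph is a hypothetical helper, not part of UGATIT or TensorFlow.

```python
import os
import tempfile


def find_meta_graph(checkpoint_dir):
    """Return the newest .meta file directly inside `checkpoint_dir`.

    If there is none, raise FileNotFoundError naming any subdirectories that
    do contain .meta files, which usually points at the directory to pass.
    """
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError("not a directory: %s" % checkpoint_dir)
    metas = [f for f in os.listdir(checkpoint_dir) if f.endswith(".meta")]
    if metas:
        # Pick the most recently modified file rather than os.listdir's arbitrary order.
        newest = max(metas, key=lambda f: os.path.getmtime(os.path.join(checkpoint_dir, f)))
        return os.path.join(checkpoint_dir, newest)
    hints = [root for root, _, files in os.walk(checkpoint_dir)
             if any(f.endswith(".meta") for f in files)]
    raise FileNotFoundError("no .meta file in %s; directories that contain one: %s"
                            % (checkpoint_dir, ", ".join(hints) if hints else "none"))


if __name__ == "__main__":
    # Demo with a throwaway tree mimicking checkpoint/<model_dir>/ (names are made up).
    with tempfile.TemporaryDirectory() as root:
        model_dir = os.path.join(root, "UGATIT_selfie2anime")
        os.makedirs(model_dir)
        open(os.path.join(model_dir, "UGATIT.model-1000000.meta"), "w").close()
        print(find_meta_graph(model_dir))
        try:
            find_meta_graph(root)
        except FileNotFoundError as err:
            print(err)
```

In the script above, the meta_graph list and the assertion could then be replaced by meta_path = find_meta_graph(args.checkpoint_dir), passing meta_path to tf.train.import_meta_graph.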