affinelayer / pix2pix-tensorflow

Tensorflow port of Image-to-Image Translation with Conditional Adversarial Nets https://phillipi.github.io/pix2pix/
MIT License

How to get predict_fake images from exported models #203

4iue0 closed this issue 2 years ago

4iue0 commented 3 years ago

Hi, I want to get the predict_fake images from an exported model so that I can fix the outputs. In pix2pix.py, around line 476, the model is returned like this:

    return Model(
        predict_real=predict_real,
        predict_fake=predict_fake,
        discrim_loss=ema.average(discrim_loss),
        discrim_grads_and_vars=discrim_grads_and_vars,
        gen_loss_GAN=ema.average(gen_loss_GAN),
        gen_loss_L1=ema.average(gen_loss_L1),
        gen_grads_and_vars=gen_grads_and_vars,
        outputs=outputs,
        train=tf.group(update_losses, incr_global_step, gen_train),
    )

predict_fake is part of the returned Model, so I think I should be able to get it.

Please tell me, how to get predict_fake images from exported models.

I use the code below to generate images from exported models.

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function

    import tensorflow as tf
    import numpy as np
    import argparse
    import json
    import base64

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", required=True, help="directory containing exported model")
    parser.add_argument("--input_file", required=True, help="input PNG image file")
    parser.add_argument("--output_file", required=True, help="output PNG image file")
    a = parser.parse_args()


    def main():
        # Read the input PNG and wrap it as a base64-encoded instance.
        with open(a.input_file, "rb") as f:
            input_data = f.read()

        input_instance = dict(input=base64.urlsafe_b64encode(input_data).decode("ascii"), key="0")
        input_instance = json.loads(json.dumps(input_instance))

        with tf.Session() as sess:
            # Restore the exported graph and its weights.
            saver = tf.train.import_meta_graph(a.model_dir + "/export.meta")
            saver.restore(sess, a.model_dir + "/export")

            # The export stores its input/output tensor names as JSON in graph collections.
            input_vars = json.loads(tf.get_collection("inputs")[0])
            output_vars = json.loads(tf.get_collection("outputs")[0])
            input = tf.get_default_graph().get_tensor_by_name(input_vars["input"])
            output = tf.get_default_graph().get_tensor_by_name(output_vars["output"])

            input_value = np.array(input_instance["input"])
            output_value = sess.run(output, feed_dict={input: np.expand_dims(input_value, axis=0)})[0]

        output_instance = dict(output=output_value.decode("ascii"), key="0")

        # Restore any missing base64 padding before decoding the output PNG.
        b64data = output_instance["output"]
        b64data += "=" * (-len(b64data) % 4)
        output_data = base64.urlsafe_b64decode(b64data.encode("ascii"))

        with open(a.output_file, "wb") as f:
            f.write(output_data)


    if __name__ == '__main__':
        main()
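The script pads the model output with `=` before decoding it, presumably because the exported graph returns URL-safe base64 without trailing padding. The padding arithmetic can be checked on its own with a minimal sketch that uses only the standard library. The helper name `restore_padding` and the sample bytes are made up for illustration and are not part of pix2pix-tensorflow.

```python
import base64


def restore_padding(b64data: str) -> str:
    """Append '=' until the length is a multiple of 4 (hypothetical helper)."""
    return b64data + "=" * (-len(b64data) % 4)


# Illustrative payload only: a PNG signature followed by filler bytes.
raw = b"\x89PNG\r\n\x1a\n" + b"example!"

# Encode as URL-safe base64, then strip the padding to mimic unpadded output.
encoded = base64.urlsafe_b64encode(raw).decode("ascii")
stripped = encoded.rstrip("=")

# Re-pad and decode; the round trip should give back the original bytes.
decoded = base64.urlsafe_b64decode(restore_padding(stripped).encode("ascii"))
assert decoded == raw
```

The expression `-len(b64data) % 4` gives the number of characters missing to reach the next multiple of 4, so already padded strings are left unchanged.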

Thank you!

4iue0 commented 2 years ago

Sorry, I just realized that this is not possible. Thank you.
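For anyone who finds this later: the script in the question looks up tensors through the JSON stored in the "outputs" collection, so only the keys registered at export time can be fetched. A minimal sketch of that lookup in plain Python follows. It assumes the exported model registers only `key` and `output`, which is what the script reads. The tensor names in `example_entry` and the helper `tensor_name_for` are hypothetical, not taken from a real export.

```python
import json


def tensor_name_for(collection_entry: str, key: str) -> str:
    """Return the tensor name registered under `key` (hypothetical helper)."""
    outputs = json.loads(collection_entry)
    if key not in outputs:
        raise KeyError(
            "%r is not exported; available keys: %s" % (key, sorted(outputs))
        )
    return outputs[key]


# Hypothetical stand-in for tf.get_collection("outputs")[0]; names are made up.
example_entry = json.dumps({"key": "Identity:0", "output": "EncodeBase64:0"})

# "output" is registered, so its tensor name can be resolved.
output_name = tensor_name_for(example_entry, "output")

# "predict_fake" is not registered, so the lookup fails.
try:
    tensor_name_for(example_entry, "predict_fake")
    predict_fake_available = True
except KeyError:
    predict_fake_available = False
```

If that assumption holds, getting predict_fake would need a checkpoint from training and a graph that includes the discriminator, rather than the exported model.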