JiahuiYu / generative_inpainting

DeepFill v1/v2 with Contextual Attention and Gated Convolution, CVPR 2018, and ICCV 2019 Oral
http://jiahuiyu.com/deepfill/

How to run multiple images with multiple masks respectively #12

Open TrinhQuocNguyen opened 6 years ago

TrinhQuocNguyen commented 6 years ago

Thank you for your contribution,

At the moment, I have checked the test file, and it can only run on one image/mask pair. I have tried to put the code in a for loop, but I got an error at this line:

        output = model.build_server_graph(input_image)
      File "/home/ubuntu/trinh/generative_inpainting/inpaint_model.py", line 307, in build_server_graph
        config=None)
      File "/home/ubuntu/trinh/generative_inpainting/inpaint_model.py", line 50, in build_inpaint_net
        x = gen_conv(x, cnum, 5, 1, name='conv1')
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/framework/python/ops/arg_scope.py", line 181, in func_with_args
        return func(*args, **current_args)
      File "/home/ubuntu/trinh/generative_inpainting/inpaint_ops.py", line 45, in gen_conv
        activation=activation, padding=padding, name=name)
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/layers/convolutional.py", line 608, in conv2d
        return layer.apply(inputs)
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/layers/base.py", line 671, in apply
        return self.__call__(inputs, *args, **kwargs)
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/layers/base.py", line 559, in __call__
        self.build(input_shapes[0])
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/layers/convolutional.py", line 143, in build
        dtype=self.dtype)
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/layers/base.py", line 458, in add_variable
        trainable=trainable and self.trainable)
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/variable_scope.py", line 1203, in get_variable
        constraint=constraint)
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/variable_scope.py", line 1092, in get_variable
        constraint=constraint)
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/variable_scope.py", line 425, in get_variable
        constraint=constraint)
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/variable_scope.py", line 394, in _true_getter
        use_resource=use_resource, constraint=constraint)
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/variable_scope.py", line 742, in _get_single_variable
        name, "".join(traceback.format_list(tb))))
    ValueError: Variable inpaint_net/conv1/kernel already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:

      File "/home/ubuntu/trinh/generative_inpainting/inpaint_ops.py", line 45, in gen_conv
        activation=activation, padding=padding, name=name)
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/framework/python/ops/arg_scope.py", line 181, in func_with_args
        return func(*args, **current_args)
      File "/home/ubuntu/trinh/generative_inpainting/inpaint_model.py", line 50, in build_inpaint_net
        x = gen_conv(x, cnum, 5, 1, name='conv1')

Do you have any idea what I have done wrong? Thank you.

TrinhQuocNguyen commented 6 years ago

Oh, thank you, I have found the answer: just set the parameter reuse=tf.AUTO_REUSE:

    output = model.build_server_graph(input_image, reuse=tf.AUTO_REUSE)

TensorFlow will then reuse the existing variables instead of trying to create them again.
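Why the first attempt failed, and why reuse=tf.AUTO_REUSE makes the error go away, can be illustrated with a toy model of TF1's variable-scope rules. This is a pure-Python sketch, not TensorFlow code: VariableStore, AUTO_REUSE and build_net below are hypothetical stand-ins that only mimic the behavior described in the error message (creating an existing variable with reuse off raises, AUTO_REUSE hands back the existing one).

```python
AUTO_REUSE = object()  # hypothetical sentinel standing in for tf.AUTO_REUSE


class VariableStore:
    """Toy model of TF1 get_variable() reuse rules (not TensorFlow's code)."""

    def __init__(self):
        self._vars = {}

    def get_variable(self, name, init_value, reuse=False):
        exists = name in self._vars
        if reuse is AUTO_REUSE:
            # create on the first call, return the same object afterwards
            if not exists:
                self._vars[name] = [init_value]
            return self._vars[name]
        if reuse:
            # reuse=True: the variable must already exist
            if not exists:
                raise ValueError(f"Variable {name} does not exist")
            return self._vars[name]
        # reuse off: creating a second variable with the same name is an error
        if exists:
            raise ValueError(f"Variable {name} already exists, disallowed.")
        self._vars[name] = [init_value]
        return self._vars[name]


def build_net(store, reuse=False):
    # hypothetical: mimics gen_conv(..., name='conv1') creating this kernel
    return store.get_variable("inpaint_net/conv1/kernel", 0.0, reuse=reuse)


if __name__ == "__main__":
    store = VariableStore()
    first = build_net(store)
    try:
        build_net(store)  # second build in the same graph, reuse off
    except ValueError as err:
        print(err)
    print(build_net(store, reuse=AUTO_REUSE) is first)
```

Note that even with reuse, each build inside the loop still adds new ops to the graph; only the variables are shared.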

JiahuiYu commented 6 years ago

It would be even more efficient to build the graph ONCE with a placeholder and feed your images through sess.run. A related issue is #8.
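The efficiency point can be sketched with another toy model, under the assumption that every call to the graph-building function appends a full copy of the network's ops to the one default graph (variables may be shared, but ops are not). ToyGraph, build_server_graph_toy and OPS_PER_BUILD are hypothetical names for illustration, not TensorFlow's API; the returned lambda stands in for running the network on fed data.

```python
class ToyGraph:
    """Hypothetical stand-in for a TF1 default graph: it only records ops."""

    def __init__(self):
        self.ops = []

    def add_op(self, name):
        self.ops.append(name)
        return len(self.ops) - 1


OPS_PER_BUILD = 50  # hypothetical op count of one inpainting network


def build_server_graph_toy(graph):
    # every build appends a fresh copy of the network's ops to the graph
    for i in range(OPS_PER_BUILD):
        graph.add_op(f"op_{i}")
    # the "network" here just doubles each fed value
    return lambda fed: [2 * x for x in fed]


def per_image_builds(images):
    graph = ToyGraph()
    results = []
    for img in images:
        output = build_server_graph_toy(graph)  # rebuilt for every image
        results.append(output(img))
    return graph, results


def build_once_then_feed(images):
    graph = ToyGraph()
    output = build_server_graph_toy(graph)  # built once, like the placeholder graph
    results = [output(img) for img in images]  # analogous to sess.run(..., feed_dict=...)
    return graph, results


if __name__ == "__main__":
    imgs = [[1, 2], [3, 4], [5, 6]]
    g1, r1 = per_image_builds(imgs)
    g2, r2 = build_once_then_feed(imgs)
    print(len(g1.ops), len(g2.ops), r1 == r2)
```

With three images, per_image_builds leaves 150 ops in the graph and build_once_then_feed leaves 50, while both return the same results. In real TF1 code, the second pattern corresponds to one tf.placeholder plus repeated sess.run(output, feed_dict=...).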

TrinhQuocNguyen commented 6 years ago

Hello JiahuiYu, thank you for your quick response. Did you mean sess.run? I'm reading your source code to understand what you have done.

JiahuiYu commented 6 years ago

Sorry, typo.

TrinhQuocNguyen commented 6 years ago

Hello JiahuiYu, thank you for your response. I'm building the graph now. In the inpaint.yml file, under the "# loss legacy" line, I found that you have configured VGG_MODEL_FILE. I have read your paper, and it does not mention transfer learning, so I wonder whether we can use a VGG16 network for transfer learning? Thank you for your help.

JiahuiYu commented 6 years ago

"We have not found perceptual loss (reconstruction loss on VGG features), style loss (squared Frobenius norm of Gram matrix computed on the VGG features) [21] and total variation (TV) loss bring noticeable improvements for image inpainting in our framework, thus are not used."

You will need to implement the VGG16 perceptual loss yourself.
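For anyone implementing these losses themselves, here is a minimal pure-Python sketch of the three terms the quote mentions, computed on feature maps that are assumed to be already extracted (e.g. from a VGG16 layer, which is not shown). The layout (C channels, each an H x W nested list), the C*H*W Gram normalization, the mean-squared feature reconstruction and the L1 (anisotropic) TV variant are all assumptions; the paper quote does not specify them, and other conventions are common.

```python
def gram_matrix(features):
    """features: list of C channels, each an H x W list of floats.
    Returns the C x C Gram matrix, normalized by C*H*W (assumed convention)."""
    c = len(features)
    h = len(features[0])
    w = len(features[0][0])
    flat = [[v for row in ch for v in row] for ch in features]  # C x (H*W)
    norm = c * h * w
    return [[sum(a * b for a, b in zip(flat[i], flat[j])) / norm
             for j in range(c)] for i in range(c)]


def style_loss(feat_pred, feat_target):
    """Squared Frobenius norm of the difference of the two Gram matrices."""
    g1, g2 = gram_matrix(feat_pred), gram_matrix(feat_target)
    return sum((a - b) ** 2 for r1, r2 in zip(g1, g2) for a, b in zip(r1, r2))


def perceptual_loss(feat_pred, feat_target):
    """Reconstruction loss on features: mean squared difference."""
    diffs = [(a - b) ** 2
             for c1, c2 in zip(feat_pred, feat_target)
             for r1, r2 in zip(c1, c2)
             for a, b in zip(r1, r2)]
    return sum(diffs) / len(diffs)


def tv_loss(image):
    """Anisotropic total variation of one H x W channel: sum of absolute
    differences between vertically and horizontally adjacent pixels."""
    h, w = len(image), len(image[0])
    vert = sum(abs(image[i + 1][j] - image[i][j]) for i in range(h - 1) for j in range(w))
    horiz = sum(abs(image[i][j + 1] - image[i][j]) for i in range(h) for j in range(w - 1))
    return vert + horiz
```

For a single 2 x 2 channel [[1, 2], [3, 4]], the Gram matrix is [[7.5]] (30 / 4) and the TV loss is 6 (4 from vertical neighbors, 2 from horizontal ones).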

TrinhQuocNguyen commented 6 years ago

Thank you for your fast response. I used your pretrained model as the starting point for transfer learning, and it saved me a lot of time on a new training set. I am reading your paper again; I think it's a great paper.

TrinhQuocNguyen commented 6 years ago

Hello JiahuiYu, thank you for your awesome code. I have tried to modify it to build the graph, but unfortunately I could not get it to work.
I see that you use the build_server_graph function, but I don't understand it well. Could you please share some code that builds the graph and then feeds images into it one by one? Thank you in advance.

TrinhQuocNguyen commented 6 years ago

Here is my current code, which uses a for loop:

    # prepare folder path
    input_folder = args.test_dir + "/input"
    mask_folder = args.test_dir + "/mask"
    output_folder = args.test_dir + "/output_" + args.checkpoint_dir.split("/")[1] + "_" +datetime.datetime.now().strftime("%Y%m%d%H%M%S")

    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # start sess configuration
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True

    dir_files = os.listdir(input_folder)
    dir_files.sort()

    for file_inter in dir_files:
        sess = tf.Session(config=sess_config)

        base_file_name = os.path.basename(file_inter)

        image = cv2.imread(input_folder + "/" + base_file_name)
        mask = cv2.imread(mask_folder + "/" + base_file_name)

        assert image.shape == mask.shape

        h, w, _ = image.shape
        grid = 1
        image = image[:h//grid*grid, :w//grid*grid, :]
        mask = mask[:h//grid*grid, :w//grid*grid, :]
        print('Shape of image: {}'.format(image.shape))

        image = np.expand_dims(image, 0)
        mask = np.expand_dims(mask, 0)
        input_image = np.concatenate([image, mask], axis=2)

        input_image = tf.constant(input_image, dtype=tf.float32)
        output = model.build_server_graph(input_image, reuse=tf.AUTO_REUSE)
        output = (output + 1.) * 127.5
        output = tf.reverse(output, [-1])
        output = tf.saturate_cast(output, tf.uint8)
        # load pretrained model
        vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        assign_ops = []
        for var in vars_list:
            vname = var.name
            from_name = vname
            var_value = tf.contrib.framework.load_variable(args.checkpoint_dir, from_name)
            assign_ops.append(tf.assign(var, var_value))
        sess.run(assign_ops)
        print('Model loaded.')
        result = sess.run(output)

        # write to output folder
        cv2.imwrite(output_folder + "/" + base_file_name, result[0][:, :, ::-1])
        sess.close()

JiahuiYu commented 6 years ago

Your usage is actually not correct. In TensorFlow-based code, the graph-building function should be called only once, unless you intend to reuse the graph. I've modified the code for your case. Please use the following:

    import os
    import time

    import cv2
    import numpy as np
    import tensorflow as tf

    from inpaint_model import InpaintCAModel

    # args is assumed to be parsed with argparse and to provide
    # flist, image_height, image_width and checkpoint_dir

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)

    # build the graph ONCE; the placeholder holds image and mask side by side
    model = InpaintCAModel()
    input_image_ph = tf.placeholder(
        tf.float32, shape=(1, args.image_height, args.image_width*2, 3))
    output = model.build_server_graph(input_image_ph)
    # rescale from [-1, 1] to [0, 255] and convert to uint8
    output = (output + 1.) * 127.5
    output = tf.reverse(output, [-1])
    output = tf.saturate_cast(output, tf.uint8)

    # load pretrained model: copy each variable's value from the checkpoint
    vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    assign_ops = []
    for var in vars_list:
        vname = var.name
        from_name = vname
        var_value = tf.contrib.framework.load_variable(
            args.checkpoint_dir, from_name)
        assign_ops.append(tf.assign(var, var_value))
    sess.run(assign_ops)
    print('Model loaded.')

    # each line of the file list: <image path> <mask path> <output path>
    with open(args.flist, 'r') as f:
        lines = f.read().splitlines()
    t = time.time()
    for line in lines:
        image, mask, out = line.split()
        base = os.path.basename(mask)

        image = cv2.imread(image)
        mask = cv2.imread(mask)
        image = cv2.resize(image, (args.image_width, args.image_height))
        mask = cv2.resize(mask, (args.image_width, args.image_height))
        # cv2.imwrite(out, image*(1-mask/255.) + mask)
        # # continue
        # image = np.zeros((128, 256, 3))
        # mask = np.zeros((128, 256, 3))

        assert image.shape == mask.shape

        h, w, _ = image.shape
        grid = 4
        image = image[:h//grid*grid, :w//grid*grid, :]
        mask = mask[:h//grid*grid, :w//grid*grid, :]
        print('Shape of image: {}'.format(image.shape))

        image = np.expand_dims(image, 0)
        mask = np.expand_dims(mask, 0)
        input_image = np.concatenate([image, mask], axis=2)

        # run inference: only new data is fed, the graph is not rebuilt
        result = sess.run(output, feed_dict={input_image_ph: input_image})
        print('Processed: {}'.format(out))
        cv2.imwrite(out, result[0][:, :, ::-1])

    print('Time total: {}'.format(time.time() - t))

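
The script above reads args.flist, a text file in which each line holds an input image path, a mask path and an output path separated by whitespace (image, mask, out = line.split()). Below is a minimal sketch for generating such a file from the input/mask folder layout used in the earlier loop, plus a helper that computes the array shape fed to the placeholder after resizing and grid cropping. The function names, the "output" folder and the assumption that each image and its mask share a filename are hypothetical choices for illustration; because lines are split on whitespace, paths must not contain spaces.

```python
import os


def write_flist(test_dir, flist_path, out_dir_name="output"):
    """Write one 'image mask output' line per image that has a same-named mask.
    Hypothetical helper; assumes the test_dir/input and test_dir/mask layout."""
    input_dir = os.path.join(test_dir, "input")
    mask_dir = os.path.join(test_dir, "mask")
    out_dir = os.path.join(test_dir, out_dir_name)
    os.makedirs(out_dir, exist_ok=True)
    lines = []
    for name in sorted(os.listdir(input_dir)):
        mask_path = os.path.join(mask_dir, name)
        if not os.path.isfile(mask_path):
            continue  # skip images without a matching mask
        paths = (os.path.join(input_dir, name), mask_path, os.path.join(out_dir, name))
        if any(" " in p for p in paths):
            raise ValueError(f"path contains a space, line.split() would break: {paths}")
        lines.append(" ".join(paths))
    with open(flist_path, "w") as f:
        # one line per entry; an empty list writes an empty file, not a blank line
        f.write("".join(line + "\n" for line in lines))
    return lines


def fed_shape(image_height, image_width, grid=4):
    """Shape of input_image after resize, grid crop and width-wise concat.
    It must equal the placeholder shape (1, image_height, image_width*2, 3),
    which only holds when both sizes are multiples of grid."""
    h = image_height // grid * grid
    w = image_width // grid * grid
    return (1, h, 2 * w, 3)
```

fed_shape(256, 256) gives (1, 256, 512, 3), matching a placeholder built with image_height = image_width = 256; fed_shape(250, 250) gives (1, 248, 496, 3), which would not match (1, 250, 500, 3), so sess.run would reject the feed. Choosing sizes that are multiples of 4 avoids this.
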
TrinhQuocNguyen commented 6 years ago

Hi JiahuiYu, thank you very much for your code and your contribution. I am so excited to check it out. Thank you again 😄 😄 😄 😄 😄 😄 😄 😄 😄 😄

TrinhQuocNguyen commented 6 years ago

Hi JiahuiYu, wow, it worked. Thank you very much, you have saved me tons of time. 😍 😍 😍

JiahuiYu commented 6 years ago

No problem. :)

Bingmang commented 5 years ago

This code should be added to the master branch 😍 😍 😍

TianLuluC commented 5 years ago

@Bingmang, has this code been added to the for loop of test.py? Thank you.

JiahuiYu commented 5 years ago

I have left this thread open so others can use it as a reference.

zylxadz commented 4 years ago

@TrinhQuocNguyen, thank you very much for the discussion about training a new model! Could you give me more detailed instructions on training a model with transfer learning from the pretrained weights? Thanks a lot!

minushuang commented 4 years ago

great!

JeremyCJM commented 4 years ago

In JiahuiYu's snippet above, for the current version of the repository, where build_server_graph takes the config object as its first argument, the line

    output = model.build_server_graph(input_image_ph)

should be:

    output = model.build_server_graph(FLAGS, input_image_ph)

where FLAGS is the loaded config (the repository's test.py creates it with FLAGS = ng.Config('inpaint.yml')).

arnavmehta7 commented 2 years ago

Hey, I have been trying for days to customize part of the model. Can you explain how to access the model and run model.summary()?