jiangsutx / SRN-Deblur

Repository for Scale-recurrent Network for Deep Image Deblurring
http://www.xtao.website/projects/srndeblur/srndeblur_cvpr18.pdf
MIT License

Question about testing grayscale images! Could you please help me check whether my modified code is correct? #49

Closed: nwpuqyj closed this issue 4 years ago

nwpuqyj commented 4 years ago

Thank you for your code! My question is: could you please help me check whether my modified code is correct? I modified the "test" function in model.py to test grayscale images. The modified code runs, but I'm not sure whether the details are correct. The modified lines are marked with "Here!". Looking forward to your reply!

    def test(self, height, width, input_path, output_path):
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        imgsName = sorted(os.listdir(input_path))

        H, W = height, width
        inp_chns = 3 if self.args.model == 'color' else 1
        self.batch_size = 1                                                    #Here!
        inputs = tf.placeholder(shape=[self.batch_size, H, W, inp_chns], dtype=tf.float32)
        outputs = self.generator(inputs, reuse=False)

        sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))

        self.saver = tf.train.Saver()
        self.load(sess, self.train_dir, step=3000)

        for imgName in imgsName:
            blur = scipy.misc.imread(os.path.join(input_path, imgName))
            blur = np.expand_dims(blur, 2)                           #Here!
            h, w, c = blur.shape
            # make sure the width is larger than the height
            rot = False
            if h > w:
                blur = np.transpose(blur, [1, 0, 2])
                rot = True
            h = int(blur.shape[0])
            w = int(blur.shape[1])
            resize = False
            if h > H or w > W:
                scale = min(1.0 * H / h, 1.0 * W / w)
                new_h = int(h * scale)
                new_w = int(w * scale)
                blur = scipy.misc.imresize(blur, [new_h, new_w], 'bicubic')
                resize = True
                blurPad = np.pad(blur, ((0, H - new_h), (0, W - new_w), (0, 0)), 'edge')
            else:
                blurPad = np.pad(blur, ((0, H - h), (0, W - w), (0, 0)), 'edge')
            blurPad = np.expand_dims(blurPad, 0)
            if self.args.model != 'color':
                blurPad = np.transpose(blurPad, (3, 1, 2, 0))

            start = time.time()
            deblur = sess.run(outputs, feed_dict={inputs: blurPad / 255.0})
            duration = time.time() - start
            print('Saving results: %s ... %4.3fs' % (os.path.join(output_path, imgName), duration))
            res = deblur[-1]
            if self.args.model != 'color':
                res = np.transpose(res, (3, 1, 2, 0))
            res = im2uint8(res[0, :, :,0])                            #Here!
            # crop the image into original size
            if resize:
                res = res[:new_h, :new_w]                               #Here!
                res = scipy.misc.imresize(res, [h, w], 'bicubic')
            else:
                res = res[:h, :w]                                                 #Here!
            if rot:
                res = np.transpose(res, [1, 0])                           #Here!
            scipy.misc.imsave(os.path.join(output_path, imgName), res)
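
A minimal sketch of the pad-and-crop round trip for one grayscale image, assuming the image already fits inside the H x W placeholder (so the resize branch is skipped). It keeps the image 2-D until the batch is built, because scipy.misc.imresize appears to reject a (h, w, 1) array, while a plain 2-D array is a valid grayscale image. The helpers to_gray_batch and from_gray_batch are hypothetical, not part of SRN-Deblur. Note also that scipy.misc.imread, imresize and imsave have been removed from recent SciPy releases.

```python
import numpy as np


def to_gray_batch(blur, H, W):
    """Turn a 2-D grayscale uint8 image into a [1, H, W, 1] float batch.

    Hypothetical helper: assumes blur has shape (h, w) and fits inside H x W.
    Returns the batch plus what is needed to undo the padding and rotation.
    """
    if blur.ndim != 2:
        raise ValueError('expected a 2-D grayscale image, got shape %s' % (blur.shape,))
    rot = False
    if blur.shape[0] > blur.shape[1]:
        # Keep width >= height, as the posted test() does.
        blur = blur.T
        rot = True
    h, w = blur.shape
    if h > H or w > W:
        raise ValueError('image is larger than the placeholder; resize it first')
    # Edge-pad to the fixed placeholder size, then add batch and channel axes.
    pad = np.pad(blur, ((0, H - h), (0, W - w)), 'edge')
    batch = pad[np.newaxis, :, :, np.newaxis].astype(np.float32) / 255.0
    return batch, h, w, rot


def from_gray_batch(out, h, w, rot):
    """Undo to_gray_batch on a [1, H, W, 1] network output in [0, 1]."""
    res = out[0, :, :, 0]
    # Clip, scale and round to uint8 (similar in spirit to the repo's im2uint8).
    res = (np.clip(res, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
    res = res[:h, :w]          # crop away the edge padding
    if rot:
        res = res.T            # undo the width >= height rotation
    return res


# Usage: a tall 4 x 3 image is rotated, padded to 8 x 8, then restored.
img = np.arange(12, dtype=np.uint8).reshape(4, 3)
batch, h, w, rot = to_gray_batch(img, 8, 8)
print(batch.shape)  # (1, 8, 8, 1)
restored = from_gray_batch(batch, h, w, rot)
```

Feeding the identity here stands in for the network; in test() the output of sess.run would be passed to from_gray_batch instead.
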

jiangsutx commented 4 years ago

Sorry, I cannot check the code line by line.

  1. Make sure you are using the gray model.

  2. Also make sure that blurPad, the input to the line deblur = sess.run(outputs, feed_dict={inputs: blurPad / 255.0}), has the correct shape, that is, [batch, height, width, channel = 1]. If so, the inference should be correct.

  3. Sometimes an image looks grayscale but actually has 3 channels. So please make sure the shape at each step is what you expect.
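
One way to follow point 3 is to normalize every input to a 2-D array before padding. The sketch below is an assumption-laden helper (ensure_gray is hypothetical, not part of the repo, and assumes 8-bit input): it drops a trailing singleton channel, collapses a 3-channel image whose channels are identical, and for a genuinely colored image falls back to the ITU-R BT.601 luma weights, a swapped-in choice rather than anything SRN-Deblur prescribes.

```python
import numpy as np


def ensure_gray(img):
    """Return a 2-D (h, w) uint8 array, whatever channel layout imread produced."""
    img = np.asarray(img)
    if img.ndim == 2:
        return img
    if img.ndim == 3 and img.shape[2] == 1:
        # (h, w, 1): just drop the singleton channel axis.
        return img[:, :, 0]
    if img.ndim == 3 and img.shape[2] in (3, 4):
        rgb = img[:, :, :3]  # ignore an alpha channel if present
        if (np.array_equal(rgb[:, :, 0], rgb[:, :, 1])
                and np.array_equal(rgb[:, :, 1], rgb[:, :, 2])):
            # "Looks gray but has 3 channels": every channel holds the same data.
            return rgb[:, :, 0]
        # Truly colored: BT.601 luma weights (an assumption), rounded to uint8.
        luma = rgb.astype(np.float64) @ np.array([0.299, 0.587, 0.114])
        return np.clip(luma + 0.5, 0, 255).astype(np.uint8)
    raise ValueError('unexpected image shape %s' % (img.shape,))


# Usage: a 2 x 2 image stored with three identical channels.
gray3 = np.stack([np.full((2, 2), 7, np.uint8)] * 3, axis=2)  # shape (2, 2, 3)
print(ensure_gray(gray3).shape)  # (2, 2)
```

Printing the shape after each step of test(), as suggested above, is still the simplest way to confirm the input really is [1, height, width, 1].
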

Hope it helps.

nwpuqyj commented 4 years ago

@jiangsutx Thank you for your reply! As you said, for the line deblur = sess.run(outputs, feed_dict={inputs: blurPad / 255.0}) I have set channel to 1 to satisfy [batch, height, width, channel = 1]. But what should batch_size be set to? Your original code has self.batch_size = 1 if self.args.model == 'color' else 3. When I set batch_size to 3 to satisfy [3, h, w, 1], something goes wrong. However, when I set batch_size to 1, [batch, height, width, channel] is [1, h, w, 1] and the code works. The placeholder inputs and blurPad have to share the same dimensions, either [3, h, w, 1] or [1, h, w, 1], but I don't know how to choose the value of batch_size.

jiangsutx commented 4 years ago
  1. When testing, you only process one image at a time, so batch_size should be 1.
  2. When testing on color images: (a) the color model takes all RGB channels at once, so the input shape is [1, height, width, 3]; (b) the gray model takes a single-channel image as input, and since RGB has 3 channels, we split the image into 3 single-channel images and stack them along the batch dimension, so the input shape is [3, height, width, 1].
  3. In your case, you only want to test one gray image with the gray model, so the input should be [1, height, width, 1].
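
The three layouts above can be sketched with NumPy alone. make_batch and split_batch are hypothetical helpers; the (3, 1, 2, 0) transpose is the same one the posted test() uses to move RGB channels into the batch dimension for the gray model, and since it only swaps axes 0 and 3 it is its own inverse.

```python
import numpy as np


def make_batch(img, model):
    """Build the network input for one (h, w) or (h, w, 3) image in [0, 1].

    model == 'color': RGB image  -> [1, h, w, 3]
    model == 'gray' : RGB image  -> [3, h, w, 1] (each channel is its own sample)
                      gray image -> [1, h, w, 1]
    """
    if img.ndim == 2:
        img = img[:, :, np.newaxis]        # (h, w) -> (h, w, 1)
    batch = img[np.newaxis]                # add the batch axis: [1, h, w, c]
    if model != 'color':
        # Swap batch and channel axes: [1, h, w, c] -> [c, h, w, 1].
        batch = np.transpose(batch, (3, 1, 2, 0))
    return batch


def split_batch(out, model):
    """Invert make_batch on a network output of the same shape."""
    if model != 'color':
        out = np.transpose(out, (3, 1, 2, 0))  # [c, h, w, 1] -> [1, h, w, c]
    return out[0]                              # drop the batch axis -> (h, w, c)


rgb = np.random.rand(4, 6, 3)
gray = np.random.rand(4, 6)
print(make_batch(rgb, 'color').shape)  # (1, 4, 6, 3)
print(make_batch(rgb, 'gray').shape)   # (3, 4, 6, 1)
print(make_batch(gray, 'gray').shape)  # (1, 4, 6, 1)
```

The placeholder's batch size then simply has to equal the first dimension of whatever make_batch returns, which is why self.batch_size = 1 if self.args.model == 'color' else 3 fits RGB inputs but a single gray image needs 1.
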

Hope it is clear.

nwpuqyj commented 4 years ago

@jiangsutx Thanks a lot! It is very clear!