Hi, thanks for your attention.
We tested two different discriminators for shadow removal: the default patch-based discriminator in this repo and a fully-connected discriminator.
However, we did not find any difference between them on this task.
You can try the fc_discriminator with the following code:
import tensorflow as tf  # TensorFlow 1.x (uses tf.variable_scope and tf.contrib)

# conv, lrelu, batchnorm and channel (the base filter count, ndf) are assumed
# to be defined elsewhere in this repo.
def fc_discriminator(discrim_inputs, discrim_targets):
    n_layers = 4
    layers = []

    # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2]
    inputs = tf.concat([discrim_inputs, discrim_targets], axis=3)

    # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
    with tf.variable_scope("layer_1"):
        convolved = conv(inputs, channel, stride=2)
        rectified = lrelu(convolved, 0.2)
        layers.append(rectified)

    # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
    # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
    # layer_4: [batch, 32, 32, ndf * 4] => [batch, 16, 16, ndf * 8]
    # layer_5: [batch, 16, 16, ndf * 8] => [batch, 8, 8, ndf * 8]
    for i in range(n_layers):
        with tf.variable_scope("layer_%d" % (len(layers) + 1)):
            out_channels = channel * min(2 ** (i + 1), 8)
            convolved = conv(layers[-1], out_channels, stride=2)
            normalized = batchnorm(convolved)
            rectified = lrelu(normalized, 0.2)
            layers.append(rectified)

    # layer_6: [batch, 8, 8, ndf * 8] => [batch, 1]
    with tf.variable_scope("layer_%d" % (len(layers) + 1)):
        # Global average pooling over the spatial axes: [batch, ndf * 8]
        avg = tf.reduce_mean(layers[-1], axis=[1, 2])
        flat = tf.contrib.layers.flatten(avg)
        # A single fully-connected unit produces one logit per image.
        logits = tf.contrib.layers.fully_connected(flat, 1, activation_fn=None)
        output = tf.nn.sigmoid(logits)
        layers.append(output)

    # Return the final probability and the list of all layer outputs.
    return layers[-1], layers
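To sanity-check the shape comments in the function above, here is a small pure-Python sketch that traces the output shape of each layer without running TensorFlow. It assumes that every stride-2 convolution halves the spatial size (as the comments state), and the defaults `in_channels=3` and `ndf=64` are illustrative guesses, not values confirmed by the repo. The helper name `trace_fc_discriminator_shapes` is hypothetical.

```python
def trace_fc_discriminator_shapes(size=256, ndf=64, n_layers=4):
    """Return [(layer_name, output_shape)] with the batch dimension omitted.

    Assumption: each stride-2 conv halves height and width exactly.
    """
    shapes = []

    # layer_1: stride-2 conv to ndf channels (no batch norm in the original).
    size //= 2
    shapes.append(("layer_1", (size, size, ndf)))

    # layer_2 .. layer_(n_layers + 1): channels double each time, capped at ndf * 8.
    for i in range(n_layers):
        size //= 2
        channels = ndf * min(2 ** (i + 1), 8)
        shapes.append(("layer_%d" % (len(shapes) + 1), (size, size, channels)))

    # Final layer: global average pooling followed by one fully-connected unit.
    shapes.append(("layer_%d" % (len(shapes) + 1), (1,)))
    return shapes


if __name__ == "__main__":
    for name, shape in trace_fc_discriminator_shapes():
        print(name, shape)
```

Under these assumptions the last convolutional layer is `layer_5` with shape `(8, 8, ndf * 8)`, and the fully-connected head ends up in scope `layer_6`.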
This works, thank you!
Hi,
I tried to load the SRD+ pretrained model into DHAN (downloaded from this link: https://drive.google.com/uc?id=1rEIWWLwEpbZGPyFUc9jSIQr78ZeQy5eZ). However, I get this error:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint.
This error doesn't occur when I load only the generator, so it seems that the discriminator in the code doesn't match the one in the checkpoint. If so, could you please update the pretrained model so that they match?
Thanks a lot.
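For anyone hitting the same restore error, one way to narrow it down is to compare the variable names and shapes in the current graph with those stored in the checkpoint. The sketch below is a pure-Python helper for that comparison. In TensorFlow 1.x the two inputs could be built from `tf.global_variables()` and `tf.train.list_variables(checkpoint_path)`, but that wiring is left out here. The helper name `compare_variables` and the example variable names are hypothetical and not taken from the repo or the checkpoint.

```python
def compare_variables(graph_vars, ckpt_vars):
    """Compare two {variable_name: shape} mappings.

    Returns three sorted lists of names:
      missing  - in the graph but not in the checkpoint (restore would fail),
      unused   - in the checkpoint but not in the graph,
      mismatch - present in both but with different shapes.
    """
    missing = sorted(n for n in graph_vars if n not in ckpt_vars)
    unused = sorted(n for n in ckpt_vars if n not in graph_vars)
    mismatch = sorted(
        n for n in graph_vars
        if n in ckpt_vars and tuple(graph_vars[n]) != tuple(ckpt_vars[n])
    )
    return missing, unused, mismatch


if __name__ == "__main__":
    # Made-up names and shapes, for illustration only.
    graph = {
        "generator/conv1/w": (3, 3, 3, 64),
        "discriminator/layer_6/fully_connected/weights": (512, 1),
    }
    ckpt = {
        "generator/conv1/w": (3, 3, 3, 64),
        "discriminator/layer_5/conv/filter": (4, 4, 512, 1),
    }
    missing, unused, mismatch = compare_variables(graph, ckpt)
    print("missing from checkpoint:", missing)
    print("unused in checkpoint:", unused)
    print("shape mismatch:", mismatch)
```

If the only entries reported fall under the discriminator scope, that would be consistent with the generator loading fine on its own.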