ravi02512 opened this issue 4 years ago
Maybe the authors will correct me, but high resolution is a big problem for GANs. Even if you alter the networks to get the sizes right, you will not be able to train it. The solution is either to use small crops of your images (whether that works depends on the kind of images) or to use progressive growing GANs. AnoGAN with pGANs: https://arxiv.org/pdf/1905.11034.pdf
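For the cropping route, a minimal TF 1.x sketch of what the input side could look like (the placeholder name full_images and the 64x64 patch size are my own illustration, not names from this repo): draw a random patch each step instead of downscaling the whole image.

    import tensorflow as tf

    # Feed random 64x64 patches instead of downscaled frames (illustrative only).
    full_images = tf.placeholder(tf.float32, [None, 800, 800, 1])
    patches = tf.map_fn(lambda img: tf.random_crop(img, [64, 64, 1]), full_images)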
I was able to update the code to use higher-resolution images, although I could only test square, power-of-2 sizes. Training on a dataset resized to 256x256 (10516 training, 3010 validation and 1505 test images) required more than 24 GB of GPU RAM even with batch_size=1, which I did not have. A dataset resized to 192x192 did not crash outright, but my GPU still did not have enough RAM to handle it. Size 128x128 did work with batch_size=8.
As far as I can remember, the updates consisted of changing all the hardcoded values of 64 (across the anomaly_detection, wgangp_64x64, z_encoding_izif and img_loader files) to a variable IMAGE_DIM (in my case), and setting DIM = IMAGE_DIM, ZDIM = 2*IMAGE_DIM and OUTPUT_DIM = IMAGE_DIM*IMAGE_DIM*1. I think CRITIC_ITERS should also be raised to a higher value depending on IMAGE_DIM. After that, most of the modifications happened in the GoodGenerator, GoodDiscriminator and both Encoder methods, shown below.
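For concreteness, the constants described above would look something like this (a sketch assuming IMAGE_DIM = 128; IMAGE_DIM is the only new name, the rest already exist in those files):

    IMAGE_DIM = 128                          # square, power-of-2 input size
    DIM = IMAGE_DIM                          # base channel multiplier
    ZDIM = 2 * IMAGE_DIM                     # latent dimensionality
    OUTPUT_DIM = IMAGE_DIM * IMAGE_DIM * 1   # flattened grayscale output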
def GoodGenerator(n_samples, noise=None, rand_sampling=RAND_SAMPLING, dim=DIM,
                  nonlinearity=tf.nn.relu, z_out=False, is_training=True, reuse=None):
    with tf.variable_scope('Generator', reuse=reuse):
        if noise is None:
            if rand_sampling == 'unif':
                noise = tf.random_uniform([n_samples, ZDIM], minval=-1., maxval=1.)
            elif rand_sampling == 'normal':
                noise = tf.random_normal([n_samples, ZDIM])

        # Scale factors derived from IMAGE_DIM; the original hardcoded values
        # for 64x64 were factor = 4, factor2 = 8.
        factor2 = IMAGE_DIM // 8
        factor = factor2 // 2

        output = lib.ops.linear.Linear('Generator.Input', ZDIM, factor*factor*factor2*dim, noise)
        output = tf.reshape(output, [-1, factor2*dim, factor, factor])
        # Four upsampling residual blocks: spatial size factor -> factor * 2**4 == IMAGE_DIM.
        # Channel counts use // so they stay integers in Python 3.
        output = ResidualBlock('Generator.Res1', factor2*dim, factor2*dim, 3, output, is_training=is_training, resample='up')
        output = ResidualBlock('Generator.Res2', factor2*dim, factor*dim, 3, output, is_training=is_training, resample='up')
        output = ResidualBlock('Generator.Res3', factor*dim, factor*dim//2, 3, output, is_training=is_training, resample='up')
        output = ResidualBlock('Generator.Res4', factor*dim//2, factor*dim//4, 3, output, is_training=is_training, resample='up')

        if is_training is not None:
            output = my_Normalize('Generator.OutputN', output, is_training)
        else:
            output = Normalize('Generator.OutputN', [0, 2, 3], output)
        output = tf.nn.relu(output)
        output = lib.ops.conv2d.Conv2D('Generator.Output', factor*dim//4, 1, 3, output)
        output = tf.tanh(output)

        if z_out:
            return tf.reshape(output, [-1, OUTPUT_DIM]), noise
        else:
            return tf.reshape(output, [-1, OUTPUT_DIM])
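As a sanity check on the generator's geometry: the initial reshape starts at factor x factor, and each of the four 'up' residual blocks doubles the spatial size, so the output lands exactly on IMAGE_DIM. A runnable check of that arithmetic (assuming IMAGE_DIM = 128):

    IMAGE_DIM = 128
    factor2 = IMAGE_DIM // 8   # 16
    factor = factor2 // 2      # 8; spatial size after the initial reshape
    size = factor
    for _ in range(4):         # Generator.Res1 .. Generator.Res4, resample='up'
        size *= 2              # each 'up' block doubles height and width
    assert size == IMAGE_DIM   # IMAGE_DIM/16 * 2**4 == IMAGE_DIM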
def GoodDiscriminator(inputs, dim=DIM, is_training=False, reuse=None, out_feats=True):
    with tf.variable_scope('Discriminator', reuse=reuse):
        factor2 = IMAGE_DIM // 8
        factor = factor2 // 2  # originally factor = 4, factor2 = 8 for 64x64

        output = tf.reshape(inputs, [-1, 1, IMAGE_DIM, IMAGE_DIM])
        output = lib.ops.conv2d.Conv2D('Discriminator.Input', 1, factor*dim//4, 3, output, he_init=False)
        # Alternative five-block variant, kept for reference:
        #output = ResidualBlock('Discriminator.Res1', factor*dim//8, factor*dim//4, 3, output, is_training=is_training, resample='down')
        #output = ResidualBlock('Discriminator.Res2', factor*dim//4, factor*dim//2, 3, output, is_training=is_training, resample='down')
        #output = ResidualBlock('Discriminator.Res3', factor*dim//2, factor*dim, 3, output, is_training=is_training, resample='down')
        #output = ResidualBlock('Discriminator.Res4', factor*dim, factor2*dim, 3, output, is_training=is_training, resample='down')
        #output = ResidualBlock('Discriminator.Res5', factor2*dim, factor2*dim, 3, output, is_training=is_training, resample='down')
        # Four downsampling residual blocks: IMAGE_DIM -> IMAGE_DIM / 2**4 == factor
        output = ResidualBlock('Discriminator.Res1', factor*dim//4, factor*dim//2, 3, output, is_training=is_training, resample='down')
        output = ResidualBlock('Discriminator.Res2', factor*dim//2, factor*dim, 3, output, is_training=is_training, resample='down')
        output = ResidualBlock('Discriminator.Res3', factor*dim, factor2*dim, 3, output, is_training=is_training, resample='down')
        output = ResidualBlock('Discriminator.Res4', factor2*dim, factor2*dim, 3, output, is_training=is_training, resample='down')

        output = tf.reshape(output, [-1, factor*factor*factor2*dim])
        out_features = output
        output = lib.ops.linear.Linear('Discriminator.Output', factor*factor*factor2*dim, 1, output)

        if out_feats:
            return tf.reshape(output, [-1]), out_features
        else:
            return tf.reshape(output, [-1])
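This scaling also explains the memory numbers above: with dim = DIM = IMAGE_DIM, the first conv layer alone outputs factor*dim/4 = IMAGE_DIM^2/64 channels at full resolution, so activation memory grows roughly with the fourth power of IMAGE_DIM. A back-of-the-envelope estimate (my own arithmetic, not a measurement):

    # fp32 size of one first-layer activation tensor, per sample
    for IMAGE_DIM in (64, 128, 192, 256):
        channels = (IMAGE_DIM // 16) * IMAGE_DIM // 4  # factor * dim // 4
        elements = channels * IMAGE_DIM * IMAGE_DIM
        print(IMAGE_DIM, round(elements * 4 / 2**20), "MiB")
    # 64 -> 1 MiB, 128 -> 16 MiB, 192 -> 81 MiB, 256 -> 256 MiB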
(I haven't checked whether the Encoder methods are the same across anomaly_detection.py and z_encoding_izif.py.)
def Encoder(inputs, is_training, dim=DIM, z_dim=ZDIM, rand_sampling='normal',
            reuse=None, z_reg_type=None, denoise=None):
    with tf.variable_scope('Encoder', reuse=reuse):
        if denoise is not None:
            inputs = tf.nn.dropout(inputs, keep_prob=denoise)
        output = tf.reshape(inputs, [-1, 1, IMAGE_DIM, IMAGE_DIM])

        factor2 = IMAGE_DIM // 8
        factor = factor2 // 2

        output = lib.ops.conv2d.Conv2D('Encoder.Input', 1, factor*dim//4, 3, output, he_init=False)
        output = ResidualBlock('Encoder.Res1', factor*dim//4, factor*dim//2, 3, output, is_training=is_training, resample='down')
        output = ResidualBlock('Encoder.Res2', factor*dim//2, factor*dim, 3, output, is_training=is_training, resample='down')
        output = ResidualBlock('Encoder.Res3', factor*dim, factor2*dim, 3, output, is_training=is_training, resample='down')
        output = ResidualBlock('Encoder.Res4', factor2*dim, factor2*dim, 3, output, is_training=is_training, resample='down')
        output = tf.reshape(output, [-1, factor*factor*factor2*dim])
        output = lib.ops.linear.Linear('Encoder.Output', factor*factor*factor2*dim, z_dim, output)

        if z_reg_type is None:
            return output
        elif z_reg_type == 'tanh_fc':
            return tf.nn.tanh(output)
        elif z_reg_type == '3s_tanh_fc':
            return tf.nn.tanh(output) * 3
        elif z_reg_type == '05s_tanh_fc':
            return tf.nn.tanh(output) * 0.5
        elif z_reg_type == 'hard_clip':
            return tf.clip_by_value(output, -1., 1.)
        elif z_reg_type == '3s_hard_clip':
            return tf.clip_by_value(output, -3., 3.)
        elif z_reg_type == '05s_hard_clip':
            return tf.clip_by_value(output, -0.5, 0.5)
        elif z_reg_type == 'stoch_clip':
            # Stochastic clipping: https://arxiv.org/pdf/1702.04782.pdf
            if rand_sampling == 'unif':
                condition = tf.greater(tf.abs(output), 1.)
                true_case = tf.random_uniform(output.get_shape(), minval=-1., maxval=1.)
            elif rand_sampling == 'normal':
                condition = tf.greater(tf.abs(output), 3.)
                true_case = tf.random_normal(output.get_shape())
                print(bcolors.YELLOW + "\nImplementing STOCH-CLIP with NORMAL z-mapping!\n" + bcolors.ENDC)
            return tf.where(condition, true_case, output)
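For reference, here is what the 'stoch_clip' branch computes, as a standalone NumPy sketch (the 3-sigma bound matches the 'normal' case above; the function name is mine):

    import numpy as np

    def stoch_clip(z, bound=3.0, rng=np.random.default_rng(0)):
        """Re-draw any latent entry that falls outside [-bound, bound]."""
        resampled = rng.standard_normal(z.shape)          # true_case
        return np.where(np.abs(z) > bound, resampled, z)  # condition -> resample

    z = np.array([0.5, -4.2, 2.9, 7.0])
    print(stoch_clip(z))  # -4.2 and 7.0 are replaced by fresh N(0, 1) draws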
I have 800x800 images; resizing them to 64x64 will lose a lot of information, so to train on higher-resolution images we have to change the whole network. Correct me if I am wrong.