junyanz / pytorch-CycleGAN-and-pix2pix

Image-to-Image Translation in PyTorch

White spots artifacts on the picture #411

Closed yinqk closed 5 years ago

yinqk commented 6 years ago

Hi, I tried to train on my own collected dataset with the source code you provided, but the generated images always contain white spot artifacts. Do you know why? Any suggestions?

vrao9 commented 5 years ago

@yinqk I'm also facing the same problem. How did you solve it?

derekahuang commented 5 years ago

@yinqk @vrao9 I'm facing the same issue when working with greyscale mammography images

vrao9 commented 5 years ago

@derekahuang I observed that in some epochs the generated images had a large number of such patches, while in other epochs there were comparatively few. Reducing the learning rate of the discriminator helped a little to reduce those artefacts. I think the artefacts occur when the discriminator gets too powerful, i.e., when it reaches approximately 100% accuracy.
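A minimal sketch of what a lower discriminator learning rate could look like in PyTorch. The tiny `netG` and `netD` modules and the specific learning-rate values are placeholders chosen for illustration; they are not the repository's networks or recommended settings.

```python
import torch
import torch.nn as nn

# Placeholder networks (assumption): stand-ins for a real generator and discriminator.
netG = nn.Conv2d(3, 3, kernel_size=3, padding=1)
netD = nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1)

lr_G = 2e-4          # illustrative generator learning rate
lr_D = lr_G * 0.25   # illustrative: the discriminator learns 4x more slowly

# Separate optimizers make it possible to tune the two learning rates independently.
optimizer_G = torch.optim.Adam(netG.parameters(), lr=lr_G, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(netD.parameters(), lr=lr_D, betas=(0.5, 0.999))
```

The ratio between the two rates is a hyperparameter; as noted above, lowering the discriminator rate only reduced the artefacts slightly.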

derekahuang commented 5 years ago

@vrao9 Thanks for the tips. I'm training between two images that are pretty close, and the artifacts appear after just a few steps. Did you solve it with any methods besides decreasing the discriminator learning rate?

vrao9 commented 5 years ago

@derekahuang No, I did not. If you find a solution, please post it here :) Maybe you could use a perceptual loss in addition? Here is a paper that implements it for an unpaired dataset: CartoonGAN: Generative Adversarial Networks for Photo Cartoonization.
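A rough sketch of a perceptual (content) loss term: an L1 distance between feature maps of a frozen network, computed on the generated image and the input image. The `PerceptualLoss` class, the `lambda_content` weight, and the randomly initialized `extractor` are all assumptions for illustration; in practice the extractor would typically be the early layers of a pretrained network such as VGG.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PerceptualLoss(nn.Module):
    """L1 distance between feature maps of a frozen feature extractor."""

    def __init__(self, feature_extractor: nn.Module):
        super().__init__()
        self.features = feature_extractor
        for p in self.features.parameters():
            p.requires_grad_(False)  # the extractor itself is never trained

    def forward(self, generated, reference):
        return F.l1_loss(self.features(generated), self.features(reference))


# Stand-in extractor (assumption): random weights, used only to keep the sketch
# self-contained. Replace with pretrained features for real training.
extractor = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
)
perceptual = PerceptualLoss(extractor)

real_A = torch.rand(2, 3, 64, 64)                      # input images
fake_B = torch.rand(2, 3, 64, 64, requires_grad=True)  # stands in for G(real_A)

lambda_content = 10.0  # illustrative weight, to be tuned
loss_content = lambda_content * perceptual(fake_B, real_A)
loss_content.backward()  # gradients reach the generated image, not the extractor
```

This term would be added to the existing generator loss rather than replace the adversarial loss.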

tlatlbtle commented 4 years ago

@vrao9 I have the same problem. Have you solved it?

lrunaways commented 4 years ago

In the paper "Analyzing and Improving the Image Quality of StyleGAN", they solve the same problem by replacing instance normalization with weight demodulation.

As they say: "We hypothesize that the droplet artifact is a result of the generator intentionally sneaking signal strength information past instance normalization: by creating a strong, localized spike that dominates the statistics, the generator can effectively scale the signal as it likes elsewhere."
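The quoted hypothesis can be illustrated numerically with a toy example. The sketch below is plain Python, with a hand-written `instance_norm` helper and made-up data (both assumptions for illustration). It shows that a single large spike inflates the channel's standard deviation, so everything else in the channel is scaled down after normalization.

```python
import math


def instance_norm(values, eps=1e-5):
    """Normalize one channel of one sample to zero mean and unit variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


def spread(values):
    """Range (max minus min) of a list of values."""
    return max(values) - min(values)


# A flattened feature map whose content varies between -1 and 1.
content = [math.sin(0.1 * i) for i in range(1024)]

# The same content with one strong, localized spike (the "droplet").
spiked = list(content)
spiked[0] = 1000.0

plain_out = instance_norm(content)
spiked_out = instance_norm(spiked)

# Compare how much range the non-spike positions keep after normalization.
print("range without spike:", spread(plain_out[1:]))
print("range with spike:   ", spread(spiked_out[1:]))
```

Changing the height of the spike changes the scale of all other positions, which is the "signal strength information" the quote refers to.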

mathiasibsen commented 3 years ago

Did anyone ever find a good solution to this? I am trying to teach the model to remove specific objects from an image, so the input and output images are very similar.

lrunaways commented 3 years ago

@mathiasibsen Removing instance normalization and changing all convolutions to "demodulated" convolutions (as described in Section 2.2 of the StyleGAN2 paper) helped in my case.

mathiasibsen commented 3 years ago

@lrunaways Thanks a lot for the suggestions. Do you by any chance have some sample code for this?

lrunaways commented 3 years ago

@i-regular, I copied the basic Keras convolution layer and changed the call function to:

  def call(self, inputs):
    # Check if the input_shape in call() is different from that in build().
    # If they are different, recreate the _convolution_op to avoid the stateful
    # behavior.
    call_input_shape = inputs.get_shape()
    recreate_conv_op = (
        call_input_shape[1:] != self._build_conv_op_input_shape[1:])

    if recreate_conv_op:
      self._convolution_op = nn_ops.Convolution(
          call_input_shape,
          filter_shape=self.kernel.shape,
          dilation_rate=self.dilation_rate,
          strides=self.strides,
          padding=self._padding_op,
          data_format=self._conv_op_data_format)

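    # Demodulation: rescale each output filter to unit L2 norm.
    # Keras kernels are laid out as (*kernel_size, in_channels, filters), so
    # the reduction runs over every axis except the last one.
    weights = self.kernel
    reduce_axes = list(range(len(weights.shape) - 1))
    d = K.sqrt(K.sum(K.square(weights), axis=reduce_axes, keepdims=True) + 1e-8)
    weights = weights / d
    outputs = self._convolution_op(inputs, weights)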

    if self.use_bias:
      if self.data_format == 'channels_first':
        if self.rank == 1:
          # nn.bias_add does not accept a 1D input tensor.
          bias = array_ops.reshape(self.bias, (1, self.filters, 1))
          outputs += bias
        else:
          outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
      else:
        outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')

    if self.activation is not None:
      return self.activation(outputs)
    return outputs

mathiasibsen commented 3 years ago

Thanks a lot, I'll have a look. So, you redid the entire implementation in Keras?

lrunaways commented 3 years ago

I copied the convolutional layer into a separate .py file, changed the call function, and imported it as usual.

mathiasibsen commented 3 years ago

Alright, thanks. I have it implemented in PyTorch now and it is currently training. I'll let you know how it works in a few days! Thanks again.

mathiasibsen commented 3 years ago

I found my problem. I needed to specify the --eval flag during testing, as I was training with a batch_size > 1.
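For context, evaluation mode matters because layers such as batch normalization and dropout behave differently during training and testing. The sketch below uses a toy model (an assumption, not the repository's generator) to show the difference that `model.eval()` makes in plain PyTorch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model (assumption) containing the two mode-dependent layer types.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Dropout(0.5),
)
x = torch.rand(1, 3, 16, 16)

model.train()  # training mode: batch statistics and random dropout masks
with torch.no_grad():
    train_a, train_b = model(x), model(x)

model.eval()   # evaluation mode: running statistics, dropout disabled
with torch.no_grad():
    eval_a, eval_b = model(x), model(x)

print("train passes identical:", torch.equal(train_a, train_b))
print("eval passes identical: ", torch.equal(eval_a, eval_b))
```

In training mode the output for the same input changes from pass to pass, while in evaluation mode it is deterministic.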

seawee1 commented 3 years ago

Thank you so much, weight demodulation solved all my problems. In case someone wants to try it out, here's my PyTorch implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DemodulatedConv2d(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, padding=0, bias=False, dilation=1):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )
        self.bias = None
        if bias:
            self.bias = nn.Parameter(torch.randn(out_channel))

        self.stride = stride
        self.padding = padding
        self.dilation = dilation

    def forward(self, input):
        batch, in_channel, height, width = input.shape

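        # Per-filter inverse L2 norm; shape (1, out_channel).
        demod = torch.rsqrt(self.weight.pow(2).sum([2, 3, 4]) + 1e-8)
        # NOTE: this view (and the weight.view below) only succeeds when batch == 1,
        # because the weight carries no per-sample dimension to expand.
        weight = self.weight * demod.view(batch, self.out_channel, 1, 1, 1)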

        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        input = input.view(1, batch * in_channel, height, width)
        if self.bias is None:
            out = F.conv2d(input, weight, padding=self.padding, groups=batch, dilation=self.dilation, stride=self.stride)
        else:
            out = F.conv2d(input, weight, bias=self.bias, padding=self.padding, groups=batch, dilation=self.dilation, stride=self.stride)
        _, _, height, width = out.shape
        out = out.view(batch, self.out_channel, height, width)

        return out

mrartemev commented 3 years ago

Hi! Not sure that the following line would work correctly:

demod.view(batch, self.out_channel, 1, 1, 1)
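The concern appears justified: `demod` has shape `(1, out_channel)`, so viewing it as `(batch, out_channel, 1, 1, 1)` can only succeed when `batch == 1`. Because there is no per-sample style here, the same normalized kernel can be shared by the whole batch and an ordinary convolution suffices. The following is a sketch under that assumption, not the original poster's code; the class name `DemodConv2d` and the zero bias initialization are choices made for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DemodConv2d(nn.Module):
    """Conv2d whose filters are rescaled to unit L2 norm on every forward pass."""

    def __init__(self, in_channel, out_channel, kernel_size=3, stride=1,
                 padding=0, dilation=1, bias=False, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, x):
        # One scale per output filter, computed over (in_channel, kh, kw).
        demod = torch.rsqrt(
            self.weight.pow(2).sum(dim=(1, 2, 3), keepdim=True) + self.eps
        )
        weight = self.weight * demod  # shape does not depend on the batch size
        return F.conv2d(x, weight, bias=self.bias, stride=self.stride,
                        padding=self.padding, dilation=self.dilation)


layer = DemodConv2d(3, 8, kernel_size=3, padding=1)
out = layer(torch.randn(4, 3, 32, 32))  # batch size 4
print(out.shape)
```

Dropping the grouped-convolution reshaping also removes the need to fold the batch into the channel dimension, which was only required in StyleGAN2 because each sample there has its own modulated weights.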

mrgloom commented 3 years ago

I wonder how weight demodulation should work when the generator doesn't have AdaIN-like additional inputs?

Looking at the StyleGAN2 code, there is modulated_conv2d_layer.
Modulate: https://github.com/NVlabs/stylegan2/blob/6af5afc72dbeb77bb2bd49919a7b8dcfc8ea644d/training/networks_stylegan2.py#L97-L100
Demodulate: https://github.com/NVlabs/stylegan2/blob/6af5afc72dbeb77bb2bd49919a7b8dcfc8ea644d/training/networks_stylegan2.py#L102-L105

As I understand it, based on the examples above, the suggestion is to omit the modulation part and use only the demodulation part. But then it looks like a kind of weight normalization. I wonder whether just clipping the weights or adding L2 regularization on the weights would have about the same effect?
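For reference, the relation to weight normalization can be written out. In the StyleGAN2 paper's notation, with $i$ indexing input feature maps, $j$ output feature maps, $k$ the spatial footprint, and $s_i$ the per-input-channel style scale, modulation and demodulation are the first line below. Setting $s_i = 1$ (no modulation) gives the second line, and the third line is the standard weight normalization reparameterization with gain $g$ and direction $\mathbf{v}$:

```latex
\begin{aligned}
w'_{ijk} &= s_i \, w_{ijk}, &
w''_{ijk} &= \frac{w'_{ijk}}{\sqrt{\sum_{i,k} {w'_{ijk}}^{2} + \epsilon}} \\[4pt]
s_i = 1 \;&\Rightarrow\; &
w''_{ijk} &= \frac{w_{ijk}}{\sqrt{\sum_{i,k} w_{ijk}^{2} + \epsilon}} \\[4pt]
& &
\mathbf{w}_j &= g_j \, \frac{\mathbf{v}_j}{\lVert \mathbf{v}_j \rVert}
\end{aligned}
```

Under this reading, demodulation without modulation matches weight normalization with the gain fixed at $g_j = 1$, up to $\epsilon$. Weight clipping or an L2 penalty only bounds or discourages large filter norms; neither forces every filter to the same norm, so they are not guaranteed to have the same effect.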

mrgloom commented 3 years ago

Regarding L2 regularization: I can't see any visual improvement from adding self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999), weight_decay=1e-4). For larger values of weight_decay, such as 1e-1, it seems to break things a bit.

Example for weight_decay=1e-1: (image not preserved)

WayneCho commented 3 years ago

@lrunaways I am facing the same issue right now. I tried substituting all the conv2d layers with demodulated convolutions as described in StyleGAN2, and the problem seems fixed. However, I must train the model at a very small learning rate (e.g., 1e-5), or the model collapses at the very beginning. So I wonder: did you substitute all the conv2d layers (including the conv2d layers in D) with demodulated convolutions? Did you run into the same problem with the learning rate?

WuZongWei6 commented 3 years ago

@seawee1 Thanks for sharing! I am confused that your code covers only the demodulation part. Do you need to modulate before this?

AnimationFan commented 1 year ago

In my experience, you may need to check your model structure, especially the normalization and activation functions. I tried to use an attention block in my GAN, but I always got colorful spots, which I think are related to the attention block. I then found that I had forgotten to add the normalization and activation functions after it, since I had modified someone else's code. I hope this experience helps others who run into the same problem.