tamarott / SinGAN

Official pytorch implementation of the paper: "SinGAN: Learning a Generative Model from a Single Natural Image"
https://tamarott.github.io/SinGAN.htm
Other
3.31k stars 611 forks source link

Question about choice of loss and normalization techniques #62

Open yikedi opened 4 years ago

yikedi commented 4 years ago

The paper states that WGAN-GP increases the training stability so I am curious what other loss and normalization techniques you have tried.

I am trying to add spectral normalization suggested by this paper https://arxiv.org/abs/1802.05957 to improve the quality of generated images so I want to know if you have used it and if you do could you please briefly compare the results with the current one which uses WGAN-GP?

Thank you

tamarott commented 4 years ago

Thanks. We didn't try to use the spectral normalization.

15732031137 commented 4 years ago

Hello!I also want to try to use the spectrum normalization, but my program is wrong, I have checked a lot of data have not been clear, so I ask for your advice!The errors are as follows:

Traceback (most recent call last): File "SR.py", line 45, in train(opt, Gs, Zs, reals, NoiseAmp) File "E:\SinGAN-masterplus\SinGAN\training.py", line 34, in train D_curr,G_curr = init_models(opt) File "E:\SinGAN-masterplus\SinGAN\training.py", line 310, in init_models netG.apply(models.weights_init) File "E:\abcd\lib\site-packages\torch\nn\modules\module.py", line 293, in apply module.apply(fn) File "E:\abcd\lib\site-packages\torch\nn\modules\module.py", line 293, in apply module.apply(fn) File "E:\abcd\lib\site-packages\torch\nn\modules\module.py", line 294, in apply fn(self) File "E:\SinGAN-masterplus\SinGAN\models.py", line 215, in weightsinit m.weight.data.normal(0.0, 0.02) File "E:\abcd\lib\site-packages\torch\nn\modules\module.py", line 591, in getattr type(self).name, name)) AttributeError: 'Conv2d' object has no attribute 'weight'

Here are the relevant codes that I added to the spectrum normalization:

def l2normalize(v, eps=1e-12): return v / (v.norm() + eps) class SpectralNorm(nn.Module): def init(self, module, name='weight', power_iterations=1): super(SpectralNorm, self).init() self.module = module self.name = name self.power_iterations = power_iterations if not self._made_params(): self._make_params()

def _update_u_v(self):
    u = getattr(self.module, self.name + "_u")
    v = getattr(self.module, self.name + "_v")
    w = getattr(self.module, self.name + "_bar")

    height = w.data.shape[0]
    for _ in range(self.power_iterations):
        v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
        u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))

    # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
    sigma = u.dot(w.view(height, -1).mv(v))
    setattr(self.module, self.name, w / sigma.expand_as(w))

def _made_params(self):
    try:
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")
        return True
    except AttributeError:
        return False

def _make_params(self):
    w = getattr(self.module, self.name)

    height = w.data.shape[0]
    width = w.view(height, -1).data.shape[1]

    u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
    v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
    u.data = l2normalize(u.data)
    v.data = l2normalize(v.data)
    w_bar = Parameter(w.data)

    del self.module._parameters[self.name]

    self.module.register_parameter(self.name + "_u", u)
    self.module.register_parameter(self.name + "_v", v)
    self.module.register_parameter(self.name + "_bar", w_bar)

def forward(self, *args):
    self._update_u_v()
    return self.module.forward(*args)

class ConvBlock(nn.Sequential): def init(self, in_channel, out_channel, ker_size, padd, stride): super(ConvBlock,self).init() self.add_module('conv', nn.Conv2d(in_channel, out_channel, kernel_size=ker_size, stride=stride, padding=padd)),

self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=ker_size, stride=stride, padding=padd)

    # self.add_module('conv',self.conv),
    # self.add_module('norm', SpectralNorm(self.conv)),

    self.add_module('LeakyRelu', nn.LeakyReLU(0.2, inplace=True))

Thank you very much for your reading, wish you a happy life!