shimopino / papers-challenge

Paper Reading List I have already read
30 stars 2 forks source link

Large Scale GAN Training for High Fidelity Natural Image Synthesis #106

Open shimopino opened 4 years ago

shimopino commented 4 years ago

論文へのリンク

[arXiv:1809.11096] Large Scale GAN Training for High Fidelity Natural Image Synthesis

著者・所属機関

Andrew Brock, Jeff Donahue, Karen Simonyan

投稿日時(YYYY-MM-DD)

2018-09-28

1. どんなもの?

SAGANをベースに近年提案されている様々な手法を適用し、また学習時と推論時に異なる分布からサンプルリングを行うTruncation Trickを採用したところ従来のSOTAを大幅に上回るISとFIDを達成した。

GANの分析方法などかなり参考になる。

2. 先行研究と比べてどこがすごいの?

GANの学習を安定化させるための手法は数多く提案されているが、条件付きのImageNetの画像生成では実データのISが233であるのに対し、GANでは未だに52.5しか達成できていない。

本研究では(1)モデルのサイズとバッチサイズを増大させること、(2)生成画像の質と多様性をトレードオフできるTruncation Trickを採用したこと、(3)既存手法を組み合わせることで実データに近いISと飛躍的なFIDの向上を達成した。

3. 技術や手法の"キモ"はどこにある?

3.1 Scaling Up

baselineにはヒンジ損失を最小しているSAGANを使用する。クラス情報はGeneratorに関してはClass-Conditionalなバッチ正則化を使用し、DiscriminatorにはProjection Discriminatorを採用する。モデルにはOrthogonal初期化を使用する。

以下の図のようにバッチサイズやチャンネル数を増大させることでISとFIDの大きな向上が見られる。

これはバッチサイズを大きくすることで学習データにより多くのモードを含めることができるからと考えられる。しかしバッチサイズを増大させることで学習が不安定になりモード崩壊を招いてしまったためこれは後ほど検証する。

またGeneratorでのConditionalバッチ正則化では、入力となるクラスベクトルcは各層で同じものを使用することで計算コストとメモリコストを抑えて学習時間の改善に成功した。

潜在ベクトルzに関してもGeneratorの入力層にのみ代入するのではなく、中間層にも代入し異なる解像度での特徴量にも反映されるようにしている。BigGANではこの潜在ベクトルzを解像度ごとに分割したあとでクラスベクトルcと結合させて各層に入力しているが、BigGAN-deepでは潜在ベクトルzをそのまま使用している。この設計によって学習時間と精度の改善に成功した。

image

3.2 Truncation Trick

本来ならGANの潜在ベクトルは任意の分布からサンプリングできるが、実際には入力となる潜在ベクトルの分布は一様分布あるいはガウス分布が採用されており、これが最適がどうかは後ほど検証する。

潜在ベクトルzは学習時にガウス分布からサンプリングし、推論時にはtruncatedな(切断)正規分布からサンプリングすることでISとFIDの向上が見られた(切断正規分布のイメージはコメントを参照)。

この打ち切る範囲を変化させることで生成される画像の質と多様性をトレードオフにすることが可能となる。

image

この閾値を変化させることでISとFIDに関してPR曲線のような質と多様性のトレードオフの関係を捉えることができている。クラス条件付きモデルにおいて、ISは生成画像の多様性を評価できないため閾値を低くすれば大きくスコアが改善されるが、FIDは多様性を評価するために閾値をゼロにしてしまうとスコアが多くく減少していることがわかる。

image

この手法の問題点は、学習時と推論時で事前分布が異なることであり、大きなモデルを学習させた際には以下のようにTruncated normalからサンプリングした潜在ベクトルzに対してうまく画像生成ができていない。

image

この問題に対処するには、Generatorを平滑化させてTruncated normalに対して適応できるようにすること、つまり潜在ベクトルzの全領域から良質な画像を生成できることが必要だと考え、Orthogonal正則化を適用した。

image

これは重み行列が直交行列になるように制限をかけることで、何度も行列積を行ったとしてもノルムが変化しない、つまり勾配消失や勾配爆発が抑えられると期待している。

しかしこの式は正則化効果が非常に強いため、制限のゆるい新たな正則化を提案した。これは重み行列の対角要素以外の要素が大きくならないように制限をかけることで、直接直交行列になるような制限よりも精度が向上すると考えられる。

image

4. どうやって有効だと検証した?

上記の手法によりモデルは改善されたが、途中で学習が崩壊してしまう現象も発生しておりEealy Stoppingを行う必要がある。まずはこの原因を調査する。

4.1 Analysis

4.1.1 Generator

まず以下では小さなモデルでは生じなかった不安定性が大きなモデルで発生した原因を見ていく。手法としてはObena(2018)で行われていたように、学習中の重み・勾配・損失を見ていき崩壊との関連性を探った。

結果としてはAlrnoldiの反復法で得られた第3位までの特異値が不安定性に関して優位な情報を有しており、以下の図のようにGeneratorの一部の層(たいてい最初の層)が不安定性を引き起こしていると考えられる。

image

この現象が不安定性の原因なのか単なる症状なのか検証するために、Generatorにスペクトルの過度な増大が発生しないような制限を加えた。

image

この手法をスペクトル正規化と同時に使用した場合でもそうでない場合でも、多少精度の改善や学習の安定化に貢献しているが、完全に不安定性を防いでいるわけではない。

4.1.2 Discriminator

Generatorと同様の分析を行った。Generatorと異なる殆どの層のスペクトルにはノイズが混じっており、不安定性が発生した場合も特異値は爆発しているわけではなく、多少値がジャンプしているのみである。

image

途中途中のスパイクはDiscriminatorが大きな勾配を受け取っていることを示すが、AppendixFによるとフロベニウスノルムは増大していないため、この現象は上位の特異値のみに発生している現象である。

この現象はGeneratorがDiscriminatorを大きく摂動させるような画像を生成してしまう敵対的学習手法によるものだと仮定し、Discriminatorのヤコビアンに対して陽に正則化を適用した。

image

結果としては正則化の効果をγ=10によることで学習は安定化しGeneratorとDiscriminator両方の重み行列のスペクトルは平滑化したが大きく精度は減少してしまい、正則化項をγ=1に設定しても精度が大きく減少し、同様に現象はOrthogonal正則化やDropout、L2正則にも見られた。

またDiscriminatorの損失値が0に近い場合でも不安定性が発生した場合には大きく損失が増大していることも判明した。これはDiscriminatorが学習データを完全に覚えてしまうことで汎化性能が減少してしまっているからと考えられる。

以上を踏まえると、現状では学習を安定化させながら精度を保つ手法は存在せず、学習を安定化させるためにはある程度精度を犠牲にするか、学習をEarly Stoppingなどで監視する必要がある。

4.2 Result

image

各手法を用いた場合の評価指標を比較した。順に(1)Trancation Trackを使用しない、(2)FIDを最小化した(多様性を犠牲)、(3)検証データに対するIS、(4)ISを最大化した場合を比較している。

image

image

5. 議論はあるか?

shimopino commented 4 years ago

切断正規分布とは以下のように正規分布を一部の範囲のみ切り取った分布のことである。

image

shimopino commented 4 years ago

フロベニウスノルムの性質 参考解説

shimopino commented 4 years ago

BigGANのモデル構造は以下になる。潜在ベクトルを20次元程度に分割して、クラスベクトルと結合させて各層に入力している。

image

GeneratorでUpSampleを行う場合の各ResBlockの構造は以下になる。BatchNormに入力する際はバイアスは0センターで投影し、ゲインは1センターで投影を行っている。

image

DiscriminatorでDownSampleを行う場合の各ResBlockの構造は以下になる。入力と出力のチャンネル数は一致するような構造になっていることに注意。

image

shimopino commented 4 years ago

https://github.com/ajbrock/BigGAN-PyTorch https://github.com/sxhxliang/BigGAN-pytorch https://github.com/huggingface/pytorch-pretrained-BigGAN

shimopino commented 4 years ago

Orthogonal Regularization (直行正則化) はモデルのパラメータを更新する直前に適用する。

def ortho(model, strength=1e-4, blacklist=[]):
  with torch.no_grad():
    for param in model.parameters():
      # Only apply this to parameters with at least 2 axes, and not in the blacklist
      if len(param.shape) < 2 or any([param is item for item in blacklist]):
        continue
      w = param.view(param.shape[0], -1)
      grad = (2 * torch.mm(torch.mm(w, w.t())
              * (1. - torch.eye(w.shape[0], device=w.device)), w))
      param.grad.data += strength * grad.view(param.shape)

image

shimopino commented 4 years ago

まずはDiscriminatorのResblockを作っていく。

class DBlock(nn.Module):
    def __init__(self, 
                 in_channels,
                 out_channels,
                 hidden_channels=None,
                 downsample=False,
                 spectral_norm=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels \
                               if hidden_channels is not None else in_channels
        self.downsample = downsample
        self.learnable_sc = (in_channels != out_channels) or downsample
        self.spectral_norm = spectral_norm        

        if self.spectral_norm:
           self.c1 = SNConv2d(self.in_channels, self.hidden_channels, 3, 1, 1)
           self.c2 = SNConv2d(self.hidden_channels, self.out_channels, 3, 1, 1)
        else:
           self.c1 = nn.Conv2d(self.in_channels, self.hidden_channels, 3, 1, 1)
           self.c2 = nn.Conv2d(self.hidden_channels, self.out_channels, 3, 1, 1)

        self.average_pooling = nn.AvgPool2d(2)
        self.activation = nn.ReLU(inplace=True)

        # Shortcut layer
        if self.learnable_sc:
            if self.spectral_norm:
                self.c_sc = SNConv2d(in_channels, out_channels, 1, 1, 0)
            else:
                self.c_sc = nn.Conv2d(in_channels, out_channels, 1, 1, 0)

    def forward(self, x):
        x0 = x

        # shortcut branch
        if self.learnable_sc:
            x0 = self.c_sc(x0)
            if self.downsample:
                x0 = self.average_pooling(x0, 2)

        # normal branch
        x = self.activation(x)
        x = self.c1(x)
        x = self.activation(x)
        x = self.c2(x)
        if self.downsample:
            x = self.average_pooling(x, 2)

        # residual connection
        out = x + x0
        return out
shimopino commented 4 years ago

Generatorに導入するClass-Conditional BatchNormに関しては、論文中の以下の記述に従う。

We use a single shared class embedding in G, which is linearly projected to produce per-sample gains and biases for the BatchNorm layers. The bias projections are zero-centered, while the gain projections are one-centered. When employing hierarchical latent spaces, the latent vector z is split along its channel dimension into equal sized chunks, and each chunk is separately concatenated to the copy of the class embedding passed into a given block.

class ConditionalBatchNorm2d_with_skip_and_shared(nn.Module):
    def __init__(self, num_features, concat_vector_dim, spectral_norm=False):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm2d(num_features, momentum=0.001, affine=False)

        if spectral_norm:
            self.gain = SNLinear(concat_vector_dim, num_features, bias=False)
            self.bias = SNLinear(concat_vector_dim, num_features, bias=False)
        else:
            self.gain = nn.Linear(concat_vector_dim, num_features, bias=False)
            self.bias = nn.Linear(concat_vector_dim, num_features, bias=False)

    def forward(self, x, concat_vector):
        r"""feed-forward the input feature map and concatenated embedding

        Args:
            x (Tensor): the input feature map of shape [B, C, H, W]
            concat_vector (Tensor): concatenated vector of latent code and class embedding
        """
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1) # one-centered gain
        bias = self.bias(y).view(y.size(0), -1, 1, 1) # zero-centered bias
        out = self.bn(x)
        return out * gain + bias
shimopino commented 4 years ago

次にGeneratorのResblockに入力される変数が、潜在変数とクラスベクトルが結合されたものを想定して作成していく。

class GBlock(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_channels=None,
                 upsample=False,
                 concat_vector_dim=None,
                 spectral_norm=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels if hidden_channels is not None else out_channels
        self.learnable_sc = in_channels != out_channels or upsample
        self.upsample = upsample

        self.concat_vector_dim= concat_vector_dim
        self.spectral_norm = spectral_norm

        if self.spectral_norm:
            self.c1 = SNConv2d(self.in_channels, self.hidden_channels, 3, 1, 1)
            self.c2 = SNConv2d(self.hidden_channels, self.out_channels, 3, 1, 1)
        else:
            self.c1 = nn.Conv2d(self.in_channels, self.hidden_channels, 3, 1, 1)
            self.c2 = nn.Conv2d(self.hidden_channels, self.out_channels, 3, 1, 1)

        if self.num_classes == 0:
            self.b1 = nn.BatchNorm2d(self.in_channels)
            self.b2 = nn.BatchNorm2d(self.hidden_channels)
        else:
            self.b1 = ConditionalBatchNorm2d_with_skip_and_shared(
                         self.in_channels,
                         self.concat_vector_dim
                      )
            self.b2 = ConditionalBatchNorm2d_with_skip_and_shared(
                         self.hidden_channels,
                         self.concat_vector_dim
                      )

        self.activation = nn.ReLU(True)

        # Shortcut layer
        if self.learnable_sc:
            if self.spectral_norm:
                self.c_sc = SNConv2d(in_channels, out_channels, 1, 1, 0)
            else:
                self.c_sc = nn.Conv2d(in_channels, out_channels, 1, 1, 0)

    def forward(self, x, concat_vector):

        x0 = x

        # shortcut branch
        if self.upsample:
            x0 = F.interpolate(x0, scale_factor=2, mode='nearest') # upsample
        if self.learnable_sc:
            x0 = self.c_sc(x0)

        # normal branch
        ## first block
        if self.concat_vector_dim is not None:
            x = self.b1(x, concat_vector)
        else:
            x = self.b1(x)
        x = self.activation(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest') # upsample
        ## second block
        if self.concat_vector_dim is not None:
            x = self.b2(x, concat_vector)
        else:
            x = self.b2(x)
        x = self.activation(x)
        x = self.c2(x)

        # residual connection
        out = x + x0
        return out
shimopino commented 4 years ago

次にGeneratorを構築していく。なお論文に従って128x128の解像度を前提に作成していく。

image

class Generator(nn.Module):

    def __init__(self, nz=120, ny=128, ngf=64, bottom_width=4, spectral_norm=False, num_classes=0):
        super().__init__()

        self.nz= nz
        self.ny= ny
        self.ngf = ngf
        self.bottom_width = bottom_width
        self.num_classes = num_classes

    def forward(self, z, label):