shimopino / papers-challenge

Paper Reading List I have already read
30 stars 2 forks source link

Self-Attention Generative Adversarial Networks #144

Open shimopino opened 4 years ago

shimopino commented 4 years ago

論文へのリンク

[arXiv:1805.08318] Self-Attention Generative Adversarial Networks

著者・所属機関

Han Zhang, Ian Goodfellow, Dimitris Metaxas, Augustus Odena

投稿日時(YYYY-MM-DD)

2018-05-21

1. どんなもの?

2. 先行研究と比べてどこがすごいの?

従来のGANではテクスチャ情報を再現することはできているが、大域的な構造を再現することができない。このことは犬の画像を生成する際に、犬の体表面などはキレイに生成できるが、足などをうまく生成することができないことからもわかる。

これは畳み込み層は非常に小さな受容野しか有しておらず、大域的な特徴量を補足するには畳み込み層をたくさん積み重ねる必要があるためである。

そこで本研究では大域的な依存関係の補足と計算効率のバランスをとることが可能な自己注意機構を採用することで、画像の大域的な特徴量を活用しながら画像を生成することに成功し、またGANを学習させる際にいくつかのテクニックを採用することで、画像生成の評価指標であるFIDとISで当時のSOTAを達成した。

image

3. 技術や手法の"キモ"はどこにある?

3.1 Self-Attention Generative Adversarial Networks

本研究では特徴量マップを入力に、以下の自己注意計算を行うことで大域的な特徴量同士の関係を重みという形で計算している。

image

まずは画像を2つの線形結合層にかけて、互いの行列積を計算することで、特徴マップ上のある領域iとある領域jとの類似度を計算している。

image

あとは得られた重みのマップを使用して、上記2つとは異なる線形結合層を適用した特徴マップに重み付けを行っている。この際にメモリ効率を上げるために、チャンネル数をk(=1,2,4,8)分の1に減少させている。

image

最後に重み付けを行った特徴マップと、元の入力である特徴マップ同士をパラメータを掛け算して和を計算している。このγを最初は0から始めることで、ネットワークは学習初期段階には局所的な特徴量を学習し、学習が進むにつれて大域的な特徴量も学習することが期待される。

image

3.2 Spectral normalization for both generator and discriminator

本研究では上記モジュールの採用以外にも、学習時に2つのテクニックを使用している。

1つ目はSpectral Normalization (SN) をGeneratorにも適用することである。元の論文ではりぷしっつ制約を満たすようにDiscriminatorに対してのみSNを適用していた。

近年の研究ではGeneratorでもクラスのようなラベル情報を使用することで性能が向上することが知られており、GeneratorにSNを適用することでパラメータの急激な変化などを防ぐことが期待される。

実際に実験ではGeneratorにもDiscriminatorにもSNを適用することで、Discriminatorに対してより少ない更新回数で従来の性能に達成できることを発見した。

3.3 Imbalanced learning rate for generator and discriminator updates

2つ目はGeneratorとDiscriminatorで異なる学習率を採用すること (TTUR) である。

Discriminatorに対して正則化手法を適用すると学習が遅くなってしまうことが知られている。そこで既存の研究ではDiscriminatorに異なる学習率を設定することでDiscriminatorの更新回数を減らすことが可能であることが判明しており、本研究もTTURを使用することで同じ学習時間でより良い精度に達成している。

4. どうやって有効だと検証した?

すべての実験でDiscriminatorの学習率を0.0004、Generatorの学習率を0.0001に設定している。

まずは本研究で使用している学習を安定化させるための2つの手法の影響を調査する。結果からみてわかるようにベースラインと比較すると、SNをGにもDにも適用することでおおはばに評価指標の値が改善されており、TTURを使用することでより安定した学習を行えていることがわかる。

image

次に自己注意機構を適用する層を変化させることでどのように性能が変化するのかを調査した。自己注意機構を採用していない場合と比較すると、DCGANに追加する場合にはより大きな解像度の特徴マップに適用することで評価指標が改善されていることがわかる。

しかしResidual構造を採用したモデルでは、16x16の解像度の特徴マップに適用した場合に最も高い性能を発揮しており、単純にモデルの深い層に適用すればいいだけではないことがわかる。

image

Generatorの最終層から得られた注意重みを可視化すると以下になる。画像の3つの点とどの領域が最も関連しあっているのかを見てみると、空間的に隣り合っていることよりも、離れた領域であっても似たようなテクスチャ情報を有している特徴量同士が互いに関連付けられている。

image

5. 議論はあるか?

shimopino commented 4 years ago

https://github.com/heykeetae/Self-Attention-GAN/blob/master/sagan_models.py

class SelfAttention(nn.Module):
    r"""Self-Attention Module for calculating long-dependancy on feature map

    Args:
        in_channels (int): the dimension size of the input feature map.
    """
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.query = nn.Conv2d(in_channels=self.in_channels,
                               out_channels=self.in_channels // 8,
                               kernel_size=1)
        self.key   = nn.Conv2d(in_channels=self.in_channels,
                               out_channels=self.in_channels // 8,
                               kernel_size=1)
        self.value = nn.Conv2d(in_channels=self.in_channels,
                               out_channels=self.in_channels,
                               kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)
shimopino commented 4 years ago
    def forward(self, x):
        r"""
        Args:
            x (tensor): input feature map [B, C, H, W], N = (H*W)

        Returns:
            out (tensor): output feature map after self-attention apply 
        """
        B, C, H, W = x.shape
        query = self.query(x).view(B, -1, H*W).permute(0, 2, 1) # [B, N, C]
        key   = self.key(  x).view(B, -1, H*W)                  # [B, C, N]
        energy = torch.bmm(query, key)                          # [B, N, N]
        attention = self.softmax(energy)                        # [B, N, N]

        value = self.value(x).view(B, -1, H*W)                  # [B, C, N]
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(B, C, H, W)
        out = self.gamma * out + x

        return out