shimopino / papers-challenge

Paper Reading List I have already read
30 stars 2 forks source link

Weight Standardization #117

Open shimopino opened 4 years ago

shimopino commented 4 years ago

論文へのリンク

[arXiv:1903.10520] Weight Standardization

著者・所属機関

Siyuan Qiao, Huiyu Wang, Chenxi Liu, Wei Shen, Alan Yuille

投稿日時(YYYY-MM-DD)

2019-05-25

1. どんなもの?

ミニバッチ内の画像枚数が1枚や2枚といった極端な状況でも効果を発揮する正規化手法であるWeight Normalizationを提案した。GPUについき1画像を使用する設定ではWSはより大きなバッチサイズを使用したBNよりも高い精度を達成した。

WSでは畳み込み層のパラメータを正規化することで、損失や勾配のリプシッツ定数を下げる、つまりより勾配や損失がなめらかになることを示した。

image

2. 先行研究と比べてどこがすごいの?

BNはバッチサイズを大きくすることで効果を発揮する。BNに対抗してバッチサイズに依存しない正規化手法 (Group Normalization) なども考案されているがBNの精度に達成することができていない。

BNは内部共変量シフトを解消しているという説明がなされているが、先行研究[S. Santurkar. et al, 2018]によればBNは内部共変量シフトの解消ではなく、最適化手法によって得られる勾配をなだらかにすることで学習を促進することが示されている。

本研究ではこの考え方に則って勾配や損失をなだらかにする正規化手法を提案した。既存手法が特徴量の正規化に着目していたのに対し、提案されているWeight Normalizationでは学習するパラメータ自体を正規化させている。

3. 技術や手法の"キモ"はどこにある?

BNは特徴量のリプシッツ定数を小さくするな働きをしているが、本研究では最適化を行っているパラメータ自体のリプシッツ定数を考慮していない点に着目して、畳み込み層の重みを正規化する手法を考案した。

image

3.1 Standard BN

通常の畳み込み演算では入力される特徴量の全チャンネルに対して、出力チャンネル数分のKernel Filterを適用し出力される各チャンネルを計算している。

image

Weight Normalizationでは計算に使用する重みパラメータ (PyTorchなら[Cin, Cout, KH, KW]) の平均と標準偏差を計算し正規化を行う。

image image

計算からわかるようにBNやGNで行っているアフィン変換は行っておらず、これは実験でも示されるが正規化させた畳み込み層に対して再度アフィン変換がなされることで精度が悪化するためである。

3.2 勾配

4. どうやって有効だと検証した?

ImageNetを使用して各正規化手法との精度の比較した。なおBN以外はバッチサイズを1に限定している (BNはバッチサイズはGPUあたりに64) 。結果としてはGN+WSが大きなバッチサイズで計算したBNとほど同等の性能を発揮していることがわかる。またアフィン変換 (AF) により精度の悪化も確認できる。

image

またベースラインにResNet-50と101を選択し、いくつかの正規化手法とバッチサイズを変化させながら精度の比較を行った。この比較でもWSの有効性が確認できる。

image

各正規化手法での学習曲線も比較を行った。ハイパラ調整の詳細は論文を参照。

image

COCOデータセットを使用して物体検知やセグメンテーションの結果でもWSの有効性は確認できる。

image

5. 議論はあるか?

shimopino commented 4 years ago

https://github.com/joe-siyuan-qiao/pytorch-classification/tree/e6355f829e85ac05a71b8889f4fff77b9ab95d0b

shimopino commented 4 years ago

公式実装より

class Conv2d(nn.Conv2d):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                 padding, dilation, groups, bias)

    def forward(self, x):
        # return super(Conv2d, self).forward(x)

        # [Cout, Cin, K, K]
        weight = self.weight
        # Normalize [Cin, K, K] => [Cout, 1, 1, 1]
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
                                  keepdim=True).mean(dim=3, keepdim=True)
        # [Cout, Cin, K, K]
        weight = weight - weight_mean
        # std(dim=1) => [Cout, CinxKxK] => [Cout, 1, 1, 1]
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        # std.shape [Cout, 1, 1, 1] => [Cout, Cin, K, K]
        weight = weight / std.expand_as(weight)
        # use this weight
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

def BatchNorm2d(num_features):
    # channels group 32 is default
    return nn.GroupNorm(num_channels=num_features, num_groups=32)
shimopino commented 4 years ago

各種正規化手法で計算するモーメントの対象

image

Weight Normalizationで正規化を行う対象は以下になる。

image

shimopino commented 4 years ago