shimopino / papers-challenge

Paper Reading List I have already read
30 stars 2 forks source link

On Feature Normalization and Data Augmentation #51

Open shimopino opened 4 years ago

shimopino commented 4 years ago

論文へのリンク

[arXiv:2002.11102] On Feature Normalization and Data Augmentation

著者・所属機関

Boyi Li, Felix Wu, Ser-Nam Lim, Serge Belongie, Kilian Q. Weinberger

投稿日時(YYYY-MM-DD)

2020-02-25

1. どんなもの?

本論文では各サンプルから抽出した平均と標準偏差を使用して、他のサンプルを正規化させる新たなデータ増強手法を提案した。本手法は特徴量空間に適用できるため、入力されるデータに適用するCropやCutMixなどの手法と組み合わせて使用できる。画像・音声・自然言語のタスクで精度の向上を達成。

image

2. 先行研究と比べてどこがすごいの?

近年のデータ増強手法には2種類あり、画像タスクでのFlipやCropなどのラベル情報はそのままにしラベルに紐付くデータを変換する手法と、MixUpなどのラベル情報も変換する手法である。

本論文では、正規化手法にデータ増強を導入した。

3. 技術や手法の"キモ"はどこにある?

本手法は特徴量の正規化とデータの増強を組み合わせたものである。具体的には、2組の学習サンプルに対して、片方のサンプルから得られた1次モーメント(=平均)と2次モーメント(=標準偏差)を使ってもう片方のサンプルを正規化する。

ここではある1つのサンプルに対する正規化関数が可逆であることを仮定している。

image

正規化関数が可逆である場合、順伝播によって得られた特徴量から平均と標準偏差を遡って計算することができる。またこの計算はあくまでも単一のサンプルに対して行っているため、従来のバッチ正則化のような複数のサンプルをまたがって統計量を計算する手法とも組み合わせることができる。

Positional Normalizationを使用する場合には以下の数式で表現できる。

image

バッチ正則化では、ミニバッチにまたがって平均や標準偏差といったモーメントを計算するため、それぞれのサンプルに紐づくラベルの情報が消えてしまう。しかし本手法では、あくまで1つのサンプルの平均と標準偏差を計算するため、サンプルに紐づくラベルの情報がこれらの統計量に含まれている。

著者らは正規化された特徴量と各サンプルから得られたモーメントを同じサンプルだとみなしている。片方のサンプルからモーメント情報を使用して正規化することで、過剰適合を防いでいる。

image

この逆関数はPositional Normalizationを採用する場合には、以下の数式で表現できる。

image

これで画像B(例えば飛行機)のモーメント情報を、画像A(例えば猫)の特徴量に持たせることが可能となる。イメージとしては本来異なる分布を有している飛行機と猫の特徴量に対して、片方の分布を正規化し、一部の分布をもう片方の分布に近づけることで、滑らかな決定境界が引けるようにしている。

それぞれの特徴量を混ぜた場合には、ターゲットのラベル情報にも互いのラベルを混ぜる。

image

実装を試した結果としては、MoExを適用させる場所によって精度に違いが出ている。

4. どうやって有効だと検証した?

まずは様々なモデルに対して、MoExを適用し効果を発揮するのか検証した。どのモデル度のタスクにおいても精度が向上している。

image

次に既存のデータ増強手法と組み合わせて使用することで、精度が向上するのか検証した。最新のどのデータ増強手法に対しても高い性能を発揮していることがわかる。

image

次にImageNetデータセットを使用し、モデルに対してMoExを適用することで精度がどのように変化するのか検証した。ImageNetでも一貫して高い精度を達成できていることがわかる。

image

次に音声タスクに適用しどの程度の効果を発揮するのか検証した。モデルサイズが大きい場合には効果を発揮しているが、使用したデータセットにおいては小さいモデルでは効果を発揮していないことがわかる。

image

次に3Dモデルの分類タスクにMoExを適用し、どの程度の効果を発揮するのか検証した。両方のタスクでより高い精度を達成できていることがわかる。

image

次に機械翻訳タスクにMoExを適用し、どの程度の効果を発揮するのか検証した。本タスクにおいても高い性能を発揮しており、MoExが、ラベルを変動させるデータ増強手法のうち、機械翻訳タスクで初めて効果を発揮した手法である。

image

本手法の頑強性を検証するため、複雑なデータセットであるImageNet-Aにおいて様々なデータ増強手法との比較を行った。MoExはどの手法に対しても高い効果を発揮しているが、最も高い効果を発揮しているのはCutMix等の他のデータ増強手法と組み合わせた場合である。

image

5. 議論はあるか?

6. 次に読むべき論文はあるか?

論文情報・リンク

shimopino commented 4 years ago

https://github.com/Boyiliee/MoEx

shimopino commented 4 years ago
# x: a batch of features of shape (batch_size,
# channels, height, width),
# y: onehot labels of shape (batch_size, n_classes)
# norm_type: type of the normalization to use
def moex(x, y, norm_type):
    x, mean, std = normalization(x, norm_type)
    ex_index = torch.randperm(x.shape[0])
    x = x * std[ex_index] + mean[ex_index]
    y_b = y[ex_index]
    return x, y, y_b

# output: model output
# y: original labels
# y_b: labels of moments
# loss_func: loss function used originally
# lam: interpolation weight $\lambda$
def interpolate_loss(output, y, y_b, loss_func, lam):
    return lam * loss_func(output, y) + \
        (1. - lam) * loss_func(output, y_b)

def normalization(x, norm_type, epsilon=1e-5):
    # decide how to compute the moments
    if norm_type == ’pono’:
    norm_dims = [1]
    elif norm_type == ’instance_norm’:
    norm_dims = [2, 3]
    else: # layer norm
    norm_dims = [1, 2, 3]
    # compute the moments
    mean = x.mean(dim=norm_dims, keepdim=True)
    var = x.var(dim=norm_dims, keepdim=True)
    std = (var + epsilon).sqrt()
    # normalize the features, i.e., remove the moments
    x = (x - mean) / std
    return x, mean, std