shimopino / papers-challenge

Paper Reading List I have already read
30 stars 2 forks source link

A Simple Framework for Contrastive Learning of Visual Representations #38

Open shimopino opened 4 years ago

shimopino commented 4 years ago

論文へのリンク

[arXiv:2002.05709] A Simple Framework for Contrastive Learning of Visual Representations

著者・所属機関

Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton

投稿日時(YYYY-MM-DD)

2020-02-13

1. どんなもの?

本研究では(1)データ増強手法が予測タスクで重要であること、(2)表現ベクトルを学習した後に学習パラメータ付きの非線形変換を設けることで学習できる表現の質が向上すること、(3)Contrastive Learningでは大きなバッチサイズと教師あり学習時よりも長い学習時間によりより高い精度を達成できることを示した。

これらの発見を組み合わせたContrastive Learning手法であるSimCLRを提案した。以下のように既存の自己教師あり学習モデルよりも高い精度を発揮している。

image

2. 先行研究と比べてどこがすごいの?

3. 技術や手法の"キモ"はどこにある?

3.1 Constrastive Learning Framework

本手法の概要は以下の図になる。主に4つの要素から成り立っている。

  1. ある画像に対してデータ増強手法をランダムに適用し、異なるデータ増強を施した2つの画像を取得する
  2. 増強された2つの画像に対してニューラルネットワークを適用することでそれぞれの表現ベクトルを取得する。本研究ではResNetを使用しており、平均プーリング層後のd次元ベクトルを使用する
  3. Contrastive Lossに適用するために2つの表現ベクトルを小さなネットワークで変換を行う。この際にLinear->ReLU->Linearと計算し間に非線形変換を施す
  4. Contrastive Lossを計算し、xiが与えられた際にxjを識別するタスクになる。

image

ミニバッチのサイズをNとした場合、使用するデータの合計は2N個になる。同じデータから生成されたペアを正例とし、異なるデータから生成されたペアを負例として学習を行う。

損失関数には、各特徴ベクトルの類似度を考慮する。

image

全体のアルゴリズムは以下で表現される。

image

3.2 Training with Large Batch Size

大きなバッチサイズをSGD/Momemtumを使って学習させ、また線形に学習率を変化させルだけでは学習が不安定になってしまう。そこでLAPS最適化を適用している。

複数のTPUを並列に使用して学習させているため、それぞれのデバイスでバッチ正則化の統計量を計算しないように、各デバイスで平均や分散といった情報は共有させている。

3.3 Evaluation Protocol

評価のためにImageNetとCifar-10のデータセットを使用している。また学習させたモデルを評価する際には、パラメータを固定して最終出力層に線形分類器を導入し評価指標を計算させている。

バッチサイズを4096に設定して100エポック学習させている。また学習率は最初の10エポックでは線形に上昇させ、その後はWarmupを採用している。

3.4 Data Augmentation

従来の研究では複雑な処理を行っていたが、本手法は単純に画像をクロップしてリサイズするだけなので、タスクとしては以下のように画像内の領域に対してGlobal-to-Localな特徴を学習したり、隣り合った特徴を学習する形になる。

image

4. どうやって有効だと検証した?

4.1 Data Augmentation

データ増強手法を空間構造に変換をかけるものと物体の外観に変換をかけるものに大別し、個別あるいはペアとして画像に適用することでその精度を毛印象した。

image

複数の組み合わせを比較した結果、最も精度が高い手法はランダムクロップとrランダムな色変化である。

image

上記の組み合わせが最適である理由としては、以下のようにクロップされた領域に色の変化を加えると、切り取った領域の色の分布が大きく変化してしまい、モデルはこれらを正例だと判断するためより汎化性能が向上するためであると考えられる。

image

教師あり学習と比較すると適したデータ増強手法は異なっており、最も効果を発揮しているものは色の変化であり、強化学習で効果のあったAutoAugmentなどは効果を発揮していないことがわかる。

image

4.2 Encoder and Head

教師あり学習でも教師あり学習でもモデルサイズが大きいとより高い精度を発揮している。

image

次に表現ベクトルに対して行った非線形変換の効果を検証した。 到達できる精度に関しては変換前の表現ベクトルの次元数にかかわらず、一定の精度を発揮している。

image

変換前の表現ベクトルに関してContrastive Lossによる情報の損失が発生していないかを検証した。非線形変換はデータ増強の手法に関係なく学習できるため、下流タスクに必要な情報が損失されていないか、変換前の表現ベクトルと比較することで確かめた。

結果からは下流タスクに対して表現ベクトルは、非線形変換を行った後のベクトルよりもより多くの情報を有していることがわかる。

image

4.3 Loss Functions and Batch Size

image

以下はバッチサイズと学習時間の影響を比較した結果である。学習時間やバッチサイズを増大させることで、より多くのnegative pairを学習することになり、結果が改善すると考えられる。

image

4.4 Comparison with State-of-the-art

image

image

5. 議論はあるか?

6. 次に読むべき論文はあるか?

論文情報・リンク

shimopino commented 4 years ago

https://github.com/Spijkervet/SimCLR

shimopino commented 4 years ago

適用するデータ増強

s = 1
// 以下の引数は(brightness, contrast, saturation, hue)
color_jitter = torchvision.transforms.ColorJitter(
    0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
)

self.train_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomResizedCrop(size=size),
        torchvision.transforms.RandomHorizontalFlip(),  # with 0.5 probability
        torchvision.transforms.RandomApply([color_jitter], p=0.8),
        torchvision.transforms.RandomGrayscale(p=0.2),
        torchvision.transforms.ToTensor(),
    ]
)

self.test_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(size=size),
        torchvision.transforms.ToTensor()
    ]
)
shimopino commented 4 years ago

損失関数の計算

まずは負例のペアを生成するためのマスクを作成する

mask = torch.ones((batch_size * 2, batch_size * 2), dtype=bool)
// 対角成分をFalseにする
mask = mask.fill_diagonal_(0)
// 以下でTrueとFalseが互い違いのマスクを作成
for i in range(batch_size):
    mask[i, batch_size + i] = 0
    mask[batch_size + i, i] = 0
// z_i, z_j -> (batchsize, dimension)
p1 = torch.cat((z_i, z_j), dim=0)
// 以下で各サンプル同士の類似度を計算できる
// sim -> (batchsize*2, batchsize*2)
sim = torch.nn.CosineSimilarity(dim=2)(p1.unsqueeze(1), p1.unsqueeze(0))
sim = sim / temperature

// バッチサイズを指定することでi_jペアをすべて抽出
// sim_i_j, sim_j_i -> (batchsize)
sim_i_j = torch.diag(sim, self.batch_size)
sim_j_i = torch.diag(sim, -self.batch_size)

// positiveペアをすべて結合させる
// positive_samples -> (batchsize * 2, 1)
positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(
    self.batch_size * 2, 1
)

// マスクで1つずつずらしたサンプルを取得
negative_samples = sim[self.mask].reshape(self.batch_size * 2, -1)

labels = torch.zeros(self.batch_size * 2).to(self.device).long()
logits = torch.cat((positive_samples, negative_samples), dim=1)
loss = self.criterion(logits, labels)
loss /= 2 * self.batch_size