Open shimopino opened 4 years ago
適用するデータ増強
s = 1
// 以下の引数は(brightness, contrast, saturation, hue)
color_jitter = torchvision.transforms.ColorJitter(
0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
)
self.train_transform = torchvision.transforms.Compose(
[
torchvision.transforms.RandomResizedCrop(size=size),
torchvision.transforms.RandomHorizontalFlip(), # with 0.5 probability
torchvision.transforms.RandomApply([color_jitter], p=0.8),
torchvision.transforms.RandomGrayscale(p=0.2),
torchvision.transforms.ToTensor(),
]
)
self.test_transform = torchvision.transforms.Compose(
[
torchvision.transforms.Resize(size=size),
torchvision.transforms.ToTensor()
]
)
損失関数の計算
まずは負例のペアを生成するためのマスクを作成する
mask = torch.ones((batch_size * 2, batch_size * 2), dtype=bool)
// 対角成分をFalseにする
mask = mask.fill_diagonal_(0)
// 以下でTrueとFalseが互い違いのマスクを作成
for i in range(batch_size):
mask[i, batch_size + i] = 0
mask[batch_size + i, i] = 0
// z_i, z_j -> (batchsize, dimension)
p1 = torch.cat((z_i, z_j), dim=0)
// 以下で各サンプル同士の類似度を計算できる
// sim -> (batchsize*2, batchsize*2)
sim = torch.nn.CosineSimilarity(dim=2)(p1.unsqueeze(1), p1.unsqueeze(0))
sim = sim / temperature
// バッチサイズを指定することでi_jペアをすべて抽出
// sim_i_j, sim_j_i -> (batchsize)
sim_i_j = torch.diag(sim, self.batch_size)
sim_j_i = torch.diag(sim, -self.batch_size)
// positiveペアをすべて結合させる
// positive_samples -> (batchsize * 2, 1)
positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(
self.batch_size * 2, 1
)
// マスクで1つずつずらしたサンプルを取得
negative_samples = sim[self.mask].reshape(self.batch_size * 2, -1)
labels = torch.zeros(self.batch_size * 2).to(self.device).long()
logits = torch.cat((positive_samples, negative_samples), dim=1)
loss = self.criterion(logits, labels)
loss /= 2 * self.batch_size
論文へのリンク
[arXiv:2002.05709] A Simple Framework for Contrastive Learning of Visual Representations
著者・所属機関
Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton
投稿日時(YYYY-MM-DD)
2020-02-13
1. どんなもの?
本研究では(1)データ増強手法が予測タスクで重要であること、(2)表現ベクトルを学習した後に学習パラメータ付きの非線形変換を設けることで学習できる表現の質が向上すること、(3)Contrastive Learningでは大きなバッチサイズと教師あり学習時よりも長い学習時間によりより高い精度を達成できることを示した。
これらの発見を組み合わせたContrastive Learning手法であるSimCLRを提案した。以下のように既存の自己教師あり学習モデルよりも高い精度を発揮している。
2. 先行研究と比べてどこがすごいの?
3. 技術や手法の"キモ"はどこにある?
3.1 Constrastive Learning Framework
本手法の概要は以下の図になる。主に4つの要素から成り立っている。
ミニバッチのサイズをNとした場合、使用するデータの合計は2N個になる。同じデータから生成されたペアを正例とし、異なるデータから生成されたペアを負例として学習を行う。
損失関数には、各特徴ベクトルの類似度を考慮する。
全体のアルゴリズムは以下で表現される。
3.2 Training with Large Batch Size
大きなバッチサイズをSGD/Momemtumを使って学習させ、また線形に学習率を変化させルだけでは学習が不安定になってしまう。そこでLAPS最適化を適用している。
複数のTPUを並列に使用して学習させているため、それぞれのデバイスでバッチ正則化の統計量を計算しないように、各デバイスで平均や分散といった情報は共有させている。
3.3 Evaluation Protocol
評価のためにImageNetとCifar-10のデータセットを使用している。また学習させたモデルを評価する際には、パラメータを固定して最終出力層に線形分類器を導入し評価指標を計算させている。
バッチサイズを4096に設定して100エポック学習させている。また学習率は最初の10エポックでは線形に上昇させ、その後はWarmupを採用している。
3.4 Data Augmentation
従来の研究では複雑な処理を行っていたが、本手法は単純に画像をクロップしてリサイズするだけなので、タスクとしては以下のように画像内の領域に対してGlobal-to-Localな特徴を学習したり、隣り合った特徴を学習する形になる。
4. どうやって有効だと検証した?
4.1 Data Augmentation
データ増強手法を空間構造に変換をかけるものと物体の外観に変換をかけるものに大別し、個別あるいはペアとして画像に適用することでその精度を毛印象した。
複数の組み合わせを比較した結果、最も精度が高い手法はランダムクロップとrランダムな色変化である。
上記の組み合わせが最適である理由としては、以下のようにクロップされた領域に色の変化を加えると、切り取った領域の色の分布が大きく変化してしまい、モデルはこれらを正例だと判断するためより汎化性能が向上するためであると考えられる。
教師あり学習と比較すると適したデータ増強手法は異なっており、最も効果を発揮しているものは色の変化であり、強化学習で効果のあったAutoAugmentなどは効果を発揮していないことがわかる。
4.2 Encoder and Head
教師あり学習でも教師あり学習でもモデルサイズが大きいとより高い精度を発揮している。
次に表現ベクトルに対して行った非線形変換の効果を検証した。 到達できる精度に関しては変換前の表現ベクトルの次元数にかかわらず、一定の精度を発揮している。
変換前の表現ベクトルに関してContrastive Lossによる情報の損失が発生していないかを検証した。非線形変換はデータ増強の手法に関係なく学習できるため、下流タスクに必要な情報が損失されていないか、変換前の表現ベクトルと比較することで確かめた。
結果からは下流タスクに対して表現ベクトルは、非線形変換を行った後のベクトルよりもより多くの情報を有していることがわかる。
4.3 Loss Functions and Batch Size
以下はバッチサイズと学習時間の影響を比較した結果である。学習時間やバッチサイズを増大させることで、より多くのnegative pairを学習することになり、結果が改善すると考えられる。
4.4 Comparison with State-of-the-art
5. 議論はあるか?
6. 次に読むべき論文はあるか?
論文情報・リンク