Open shimopino opened 4 years ago
まずはDataLoaderから取り出した画像を、90度・180度・270度回転させてデータを増強する。
for i, data in enumerate(dataloader):
# get real/fake image data
real = data[0]
fake = netG.sample(batch_size)
# rotate real image, then concatenate them
real_90 = real.transpose(3, 2)
real_180 = real.flip(2, 3)
real_270 = real.transpose(2, 3).flip(2, 3)
real = torch.cat((real, real_90, real_180, real_270), dim=0)
# rotate fake image, then concatenate them
fake_90 = fake.transpose(3, 2)
fake_180 = fake.flip(2, 3)
fake_270 = fake.transpose(2, 3).flip(2, 3)
fake = torch.cat((fake, fake_90, fake_180, fake_270), dim=0)
あとは回転させたそれぞれの画像群に対してラベルを用意する。これは単純に0~3までのラベルを用意するだけになる。
rot_labels = torch.zeros((batch_size*4, )).to(device)
# prepare four labels for each rotated image
for i in range(4*batch_size):
if i < batch_size:
rot_labels[i] = 0
elif i < 2*batch_size:
rot_labels[i] = 1
elif i < 3*batch_size:
rot_labels[i] = 2
else:
rot_labels[i] = 3
次はDiscriminatorの予測したラベルとのクロスエントロピーを計算すればいいだけである。実際にDiscriminatorからそれぞれのラベルに属する確率が与えられた場合には以下の計算を行う。
criterion = torch.nn.BCELoss()
output, rot_probs = netD(real)
rot_loss = criterion(rot_probs, rot_labels)
提案手法の元
論文へのリンク
[arXiv:1811.11212] Self-Supervised GANs via Auxiliary Rotation Loss
著者・所属機関
Ting Chen, Xiaohua Zhai, Marvin Ritter, Mario Lucic, Neil Houlsby
投稿日時(YYYY-MM-DD)
2018-11-27
1. どんなもの?
2. 先行研究と比べてどこがすごいの?
GANの学習が不安定になってしまう原因は非定常な環境での学習を行う必要があるためである。オンライン学習の設定ではNeural Networkは一度学習したサンプルに対する決定境界を、学習が進むにつれて忘れてしまい、以前識別できていたサンプルに対して正しい識別を行うことができなくなってしまう。
GANの学習もGeneratorが生成するデータが常に変化してしまう非定常オンライン学習だと考えることができる。Discriminatorが非定常な環境で学習することで以前Generatorが生成したデータに対する決定境界を忘れてしまうため、学習が不安定になってしまう。
学習を安定化させる方法の1つはラベルなどの教師情報を活用することである。そこで本研究では、ラベル情報が存在しない場合でも自己教師あり学習を行うことで、教師情報を使用した場合の性能に近いGANを学習させることに成功した。
手法はシンプルであり、学習時に画像を回転させDiscriminatorにその回転角度を識別させることで、疑似的に教師情報を活用している場合と同じ効果を持たせている。
本手法により、Discriminatorが過去の学習したサンプルに対して異なる決定境界を引いてしまう現象を軽減することが可能となる。実際に各iterationでの線形分類機の精度を確認するとその効果がわかる。
3. 技術や手法の"キモ"はどこにある?
3.1 The Self-Supervised GAN
Generatorが生成している画像の室に依存しない、有用なデータの表現を学習させることを目指していく。既存の研究で提案されている自己教師あり学習では、画像を回転させたときの角度や画像のパッチ領域の座標なども追加タスクとして学習させている。
本研究ではDiscriminatorに追加のタスクを解かせることで自己教師あり学習を進めていく。具体的には画像を{0, 90, 180, 270}度ずつ回転させ、Discriminatorに回転させた角度を予測させることで追加タスクを解き、回転を検知するために有用な表現をDiscriminatorに転移させることを目指す。
数式としては以下のように表現される。
この追加タスクを行うことで、同じ数のiterationだけ学習させた場合の画像の分類精度が向上していることがわかる。
3.2 Collaborative Adversarial Training
Generatorの学習時には、Generatorは回転のかかっていない画像を生成し、Discriminatorは生成された画像に回転をかけて画像を入力され、その画像の回転角度と本物か偽物かを識別できるように学習を行う。
反対にDiscriminatorの学習時には、実画像に対してのみ回転をかけ、実画像に対してのみパラメータを更新させる。これでGeneratorが回転を検知しやすいような画像を生成してしまうことを防いでいる。
4. どうやって有効だと検証した?
実装ではUnCond-GANにはSNGANを採用し、Cond-GANではProjection-Discriminatorを採用している。本研究で提案しているSS-GANには、ラベル情報を必要としないself-modulated batch normalizationを採用している。
16個の画像に4種類の回転を加えて合計で64個のサンプルをミニバッチに含めている。
CIFAR10とImageNetで本手法の効果を検証するために、各iterationでのFIDを比較した。sBNを使用したSS-GANでは条件付きGANモデルに匹敵する性能を発揮していることがわかる。
他のデータセットでのFIDを比較しても、条件付きモデル同程度の性能を発揮できていることがわかる。
GANはハイパーパラメータの変化に敏感に反応してしまうため、各パラメータを変化させたときのFIDを比較した。比較結果を見てみると、本手法によりパラメータごとのFIDの変化が軽減されており、どのパラメータでも近しい性能を発揮できていることがわかる。
もしもGANが有用な特徴量を捉えることができている場合、Discriminatorの中間層から得られた表現にロジスティック分類を適用すれば高い分類精度になるはずである。そこで各手法との精度の比較を行ったところ、本手法が最も高い精度を発揮している。
この傾向は以下のようにImageNetデータセットに対しても見られる。
5. 議論はあるか?