shimopino / papers-challenge

Paper Reading List I have already read
30 stars 2 forks source link

A U-Net Based Discriminator for Generative Adversarial Networks #17

Open shimopino opened 4 years ago

shimopino commented 4 years ago

論文へのリンク

[arXiv:2002.12655] A U-Net Based Discriminator for Generative Adversarial Networks

著者・所属機関

Edgar Schönfeld, Bernt Schiele, Anna Khoreva

投稿日時(YYYY-MM-DD)

2020-02-28

1. どんなもの?

U-Netベースの構造を採用したDiscriminatorを提案した。DiscriminatorからGeneratorにper-pixelな情報とglobalな情報を送ることで、全体の一貫性と精緻な情報を保った画像を生成することに成功した。また新たにper-pixelな情報を利用した正則化項を導入した。FFHQやCelebAなどのデータセットで、BigGANを上回るFIDを達成した。

2. 先行研究と比べてどこがすごいの?

通常のGANでは、Discriminatorは生成画像と実画像に対するClassifierとして動作するため、Generatorに送られるフェードバック情報とはDiscriminatorが出力する分類確率への誤差に過ぎない。

このため生成画像の分布がGeneratorの学習が進むごとにシフトする、つまり以前生成した画像を保つことが保証されず、非連続的な画像を生成してしまう。

image

上段がU-Net GANが固定された潜在変数から生成した画像を表す。 下段が生成画像に対して、Discriminatorがピクセルごとに本物か偽物か判定した結果を表す。 初期段階の画像の頭頂部といった領域にDiscriminatorが正しく判定を行っており、学習が進むごとにGeneratorが、その領域を正しく生成できるようになっていることがわかる。

3. 技術や手法の"キモ"はどこにある?

3.1 Network Architecture

image

通常のDiscriminatorと同様に、入力画像に対して本物か偽物なのかを判定した後に、Decoderを導入し、UpSampleした画像に対して同様に本物か偽物なのかを判定する処理を加えている。

Skip-Connectionを導入しているため、出力層のチャンネルには画像全体の情報と局所的な情報が含まれている。

3.2 Loss Function

採用する損失関数はシンプルで、通常の画像に対する分類確率に対する損失と、ピクセル単位での分類確率に対する損失を組み合わせている。

Discriminator全体

image

Encoder側のDiscriminator

image

Decoder側のDiscriminator

image

Generator

image

3.3 Consistency Regularization

Decoder側の損失関数にCutMixを利用した正則化項を導入した。

image

CutMixを使用して、実画像と偽物画像とを混ぜたマスク画像を生成し、新たな合成画像を生成します。合成画像に関しては、以下の数式をもとに生成します。

image

そしてこのマスク画像を正解値として損失を計算します。

右辺第1項で計算しているものは、CutMixで生成された合成画像の各ピクセルに対して偽物画像か実画像なのかを示す分類確率です。また右辺第2項は、実画像と偽物画像に対してそれぞれピクセルごとに出力した分類確率をCutMixで合成したものです。

image

4. どうやって有効だと検証した?

FFHQとCOCO-AnimalsのデータセットでBigGANより優れたFIDとISスコアを達成している。

image

Ablation Studyを見ると、CutMixを応用した正則化項の導入が効果を発揮していることがわかる。

5. 議論はあるか?

Encoder側の分類確率とDecoder側の分類確率を軸に生成された画像を見てみると、Encoder側の精度が高い場合はより全体の一貫性を保った画像(紫)を生成しており、Decoder側の精度が高い場合はより細部の情報(オレンジ)が保たれていることがわかる。

image

この2つのNetworkは互いを補いあう関係であり、通常のGANよりもより詳細な情報をGeneratorに送っていることがわかる。

6. 次に読むべき論文はあるか?

論文情報・リンク

shimopino commented 4 years ago

とてもシンプルな手法でBigGANよりFIDが優れているのはすごい。

shimopino commented 4 years ago

CIFAR10でテストできるようにSAGANの構造をベースに以下で計算 non-local BlockはSelf-Attentionを意味する

Generator

model image size
Linear(128, 4x4x256) z[B, 128] --> h[B, 4x4x256]
ResBlock(256, 256) up h[B, 256, 4, 4] --> h[B, 256, 8, 8]
ResBlock(256, 256) up h[B, 256, 8, 8] --> h[B, 256, 16, 16]
non-Local Block h[B, 256, 16, 16] --> h[B, 256, 16, 16]
ResBlock(256, 256) up h[B, 256, 16, 16] --> h[B, 256, 32, 32]
Conv(256, 3) h[B, 256, 32, 32] --> h[B, 3, 32, 32]

だいだいの容量

params memory
4,360,000 16.60MB
shimopino commented 4 years ago

Discriminator Encoder

model image size
ResBlock(3, 128) down h[B, 3, 32, 32] --> h[B, 3, 16, 16]
ResBlock(128, 128) down h[B, 3, 16, 16] --> h[B, 128, 8, 8]
ResBlock(128, 128) h[B, 128, 8, 8] --> h[B, 128, 8, 8]
non-Local Block h[B, 128, 8, 8] --> h[B, 128, 8, 8]
ResBlock(128, 128) h[B, 128, 8, 8] --> h[B, 128, 8, 8]
Global Average Pooling h[B, 128, 8, 8] --> h[B, 128]
Linear(128, 1) h[B, 128] --> h[B, 1]

だいたいの容量

params memory
1,070,000 4.00MB
shimopino commented 4 years ago

Discriminator Decoder

このモデルはEncoderの最後のResBlockの出力 (h[B, 128, 8, 8]) を使用する

model image size
ResBlock(128, 128) h[B, 128, 8, 8] --> h[B, 128, 8, 8]
ResBlock(128, 128) h[B, 128, 8, 8] --> h[B, 128, 8, 8]
ResBlock(128, 128) up h[B, 128, 8, 8] --> h[B, 128, 16, 16]
ResBlock(128, 128) up h[B, 128, 16, 16] --> h[B, 128, 32, 32]
Conv(128, 1) h[B, 128, 32, 32] --> h[B, 1, 32, 32]
Sigmoid

だいたいの容量

params memory
1,220,000 4.60MB
shimopino commented 4 years ago

32x32の場合の全体像

image

farhodfm commented 4 years ago

Hi @KeisukeShimokawa!

Did you implement this method by yourself? I am asking because I couldn't find any code release for this paper. If so, could you share the code? It would be really helpful.

Thanks in advance

shimopino commented 4 years ago

Hi @KeisukeShimokawa!

Did you implement this method by yourself? I am asking because I couldn't find any code release for this paper. If so, could you share the code? It would be really helpful.

Thanks in advance

I am in the process of implementing this method. It is a trial-and-error process because there are parts of the paper that are unclear within the paper alone, such as learning methods.

farhodfm commented 4 years ago

@KeisukeShimokawa, thank you for your reply.

I see. Could you share the information on which parts of the paper are unclear? Maybe I missed those points that are important to understand the paper.

Thank you once again

shimopino commented 4 years ago

@KeisukeShimokawa, thank you for your reply.

I see. Could you share the information on which parts of the paper are unclear?

Thanks.

I created the pseudo code of UNet-based Discriminator including both CutMix Augmentation and Consistency Regularization. Do you think my understanding is correct?

Algorithm

In this algorithm, we have to select the optimal epoch for linearly increasing cutmix probability like initial 200 epoch that the original paper selected.

tatatawab commented 4 years ago

非常に分かりやすいシェアありがとうございます。

75741066-c5768d00-5d4c-11ea-9fa6-9ecce5406d98

上記のロスについて、ピクセル単位での分類確率の総和(平均)をとることに関してはどう解釈されていますか? 結局のところ、U-net Dの出力map全体の代表値を算出してしまうことに違和感を感じています。 代表値をとってしまうと、同じロス値に収束するような組み合わせ(U-net Dの出力map)はいっぱいありそうです。 その場合、ピクセルごとの判定結果を正しくGにフィードバックできているのかな、と。

https://github.com/KeisukeShimokawa/papers-challenge/blob/4640c006fdbb29a90841a911897c1b4f7a1d92e3/src/gan/UNetDiscriminator/models/unetgan_32.py#L296-L310

shimopino commented 4 years ago

@tatatawab さん

式(4)では、Discriminatorが出力する分類確率に対する交差エントロピーでの損失値を計算しています。ですので、分類確率ではなく各ピクセルに対する損失の総和を計算しています。

ご提示していただいている実装部分の通り、reduction="mean"でバッチ内の各画像のピクセル単位の損失値の平均値を計算しています。

なおbinary_cross_entropy_with_logitsの実装は、以下の公式コードを参考にしてください。

https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Loss.cpp#L202-L218

https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Loss.cpp#L19-L28