Open shimopino opened 4 years ago
切断正規分布とは以下のように正規分布を一部の範囲のみ切り取った分布のことである。
BigGANのモデル構造は以下になる。潜在ベクトルを20次元程度に分割して、クラスベクトルと結合させて各層に入力している。
GeneratorでUpSampleを行う場合の各ResBlockの構造は以下になる。BatchNormに入力する際はバイアスは0センターで投影し、ゲインは1センターで投影を行っている。
DiscriminatorでDownSampleを行う場合の各ResBlockの構造は以下になる。入力と出力のチャンネル数は一致するような構造になっていることに注意。
Orthogonal Regularization (直行正則化) はモデルのパラメータを更新する直前に適用する。
def ortho(model, strength=1e-4, blacklist=[]):
with torch.no_grad():
for param in model.parameters():
# Only apply this to parameters with at least 2 axes, and not in the blacklist
if len(param.shape) < 2 or any([param is item for item in blacklist]):
continue
w = param.view(param.shape[0], -1)
grad = (2 * torch.mm(torch.mm(w, w.t())
* (1. - torch.eye(w.shape[0], device=w.device)), w))
param.grad.data += strength * grad.view(param.shape)
まずはDiscriminatorのResblockを作っていく。
class DBlock(nn.Module):
def __init__(self,
in_channels,
out_channels,
hidden_channels=None,
downsample=False,
spectral_norm=True):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels \
if hidden_channels is not None else in_channels
self.downsample = downsample
self.learnable_sc = (in_channels != out_channels) or downsample
self.spectral_norm = spectral_norm
if self.spectral_norm:
self.c1 = SNConv2d(self.in_channels, self.hidden_channels, 3, 1, 1)
self.c2 = SNConv2d(self.hidden_channels, self.out_channels, 3, 1, 1)
else:
self.c1 = nn.Conv2d(self.in_channels, self.hidden_channels, 3, 1, 1)
self.c2 = nn.Conv2d(self.hidden_channels, self.out_channels, 3, 1, 1)
self.average_pooling = nn.AvgPool2d(2)
self.activation = nn.ReLU(inplace=True)
# Shortcut layer
if self.learnable_sc:
if self.spectral_norm:
self.c_sc = SNConv2d(in_channels, out_channels, 1, 1, 0)
else:
self.c_sc = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
def forward(self, x):
x0 = x
# shortcut branch
if self.learnable_sc:
x0 = self.c_sc(x0)
if self.downsample:
x0 = self.average_pooling(x0, 2)
# normal branch
x = self.activation(x)
x = self.c1(x)
x = self.activation(x)
x = self.c2(x)
if self.downsample:
x = self.average_pooling(x, 2)
# residual connection
out = x + x0
return out
Generatorに導入するClass-Conditional BatchNormに関しては、論文中の以下の記述に従う。
We use a single shared class embedding in G, which is linearly projected to produce per-sample gains and biases for the BatchNorm layers. The bias projections are zero-centered, while the gain projections are one-centered. When employing hierarchical latent spaces, the latent vector z is split along its channel dimension into equal sized chunks, and each chunk is separately concatenated to the copy of the class embedding passed into a given block.
class ConditionalBatchNorm2d_with_skip_and_shared(nn.Module):
def __init__(self, num_features, concat_vector_dim, spectral_norm=False):
super().__init__()
self.num_features = num_features
self.bn = nn.BatchNorm2d(num_features, momentum=0.001, affine=False)
if spectral_norm:
self.gain = SNLinear(concat_vector_dim, num_features, bias=False)
self.bias = SNLinear(concat_vector_dim, num_features, bias=False)
else:
self.gain = nn.Linear(concat_vector_dim, num_features, bias=False)
self.bias = nn.Linear(concat_vector_dim, num_features, bias=False)
def forward(self, x, concat_vector):
r"""feed-forward the input feature map and concatenated embedding
Args:
x (Tensor): the input feature map of shape [B, C, H, W]
concat_vector (Tensor): concatenated vector of latent code and class embedding
"""
gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1) # one-centered gain
bias = self.bias(y).view(y.size(0), -1, 1, 1) # zero-centered bias
out = self.bn(x)
return out * gain + bias
次にGeneratorのResblockに入力される変数が、潜在変数とクラスベクトルが結合されたものを想定して作成していく。
class GBlock(nn.Module):
def __init__(self,
in_channels,
out_channels,
hidden_channels=None,
upsample=False,
concat_vector_dim=None,
spectral_norm=False):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels if hidden_channels is not None else out_channels
self.learnable_sc = in_channels != out_channels or upsample
self.upsample = upsample
self.concat_vector_dim= concat_vector_dim
self.spectral_norm = spectral_norm
if self.spectral_norm:
self.c1 = SNConv2d(self.in_channels, self.hidden_channels, 3, 1, 1)
self.c2 = SNConv2d(self.hidden_channels, self.out_channels, 3, 1, 1)
else:
self.c1 = nn.Conv2d(self.in_channels, self.hidden_channels, 3, 1, 1)
self.c2 = nn.Conv2d(self.hidden_channels, self.out_channels, 3, 1, 1)
if self.num_classes == 0:
self.b1 = nn.BatchNorm2d(self.in_channels)
self.b2 = nn.BatchNorm2d(self.hidden_channels)
else:
self.b1 = ConditionalBatchNorm2d_with_skip_and_shared(
self.in_channels,
self.concat_vector_dim
)
self.b2 = ConditionalBatchNorm2d_with_skip_and_shared(
self.hidden_channels,
self.concat_vector_dim
)
self.activation = nn.ReLU(True)
# Shortcut layer
if self.learnable_sc:
if self.spectral_norm:
self.c_sc = SNConv2d(in_channels, out_channels, 1, 1, 0)
else:
self.c_sc = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
def forward(self, x, concat_vector):
x0 = x
# shortcut branch
if self.upsample:
x0 = F.interpolate(x0, scale_factor=2, mode='nearest') # upsample
if self.learnable_sc:
x0 = self.c_sc(x0)
# normal branch
## first block
if self.concat_vector_dim is not None:
x = self.b1(x, concat_vector)
else:
x = self.b1(x)
x = self.activation(x)
if self.upsample:
x = F.interpolate(x, scale_factor=2, mode='nearest') # upsample
## second block
if self.concat_vector_dim is not None:
x = self.b2(x, concat_vector)
else:
x = self.b2(x)
x = self.activation(x)
x = self.c2(x)
# residual connection
out = x + x0
return out
次にGeneratorを構築していく。なお論文に従って128x128の解像度を前提に作成していく。
class Generator(nn.Module):
def __init__(self, nz=120, ny=128, ngf=64, bottom_width=4, spectral_norm=False, num_classes=0):
super().__init__()
self.nz= nz
self.ny= ny
self.ngf = ngf
self.bottom_width = bottom_width
self.num_classes = num_classes
def forward(self, z, label):
論文へのリンク
[arXiv:1809.11096] Large Scale GAN Training for High Fidelity Natural Image Synthesis
著者・所属機関
Andrew Brock, Jeff Donahue, Karen Simonyan
投稿日時(YYYY-MM-DD)
2018-09-28
1. どんなもの?
SAGANをベースに近年提案されている様々な手法を適用し、また学習時と推論時に異なる分布からサンプルリングを行うTruncation Trickを採用したところ従来のSOTAを大幅に上回るISとFIDを達成した。
GANの分析方法などかなり参考になる。
2. 先行研究と比べてどこがすごいの?
GANの学習を安定化させるための手法は数多く提案されているが、条件付きのImageNetの画像生成では実データのISが233であるのに対し、GANでは未だに52.5しか達成できていない。
本研究では(1)モデルのサイズとバッチサイズを増大させること、(2)生成画像の質と多様性をトレードオフできるTruncation Trickを採用したこと、(3)既存手法を組み合わせることで実データに近いISと飛躍的なFIDの向上を達成した。
3. 技術や手法の"キモ"はどこにある?
3.1 Scaling Up
baselineにはヒンジ損失を最小しているSAGANを使用する。クラス情報はGeneratorに関してはClass-Conditionalなバッチ正則化を使用し、DiscriminatorにはProjection Discriminatorを採用する。モデルにはOrthogonal初期化を使用する。
以下の図のようにバッチサイズやチャンネル数を増大させることでISとFIDの大きな向上が見られる。
これはバッチサイズを大きくすることで学習データにより多くのモードを含めることができるからと考えられる。しかしバッチサイズを増大させることで学習が不安定になりモード崩壊を招いてしまったためこれは後ほど検証する。
またGeneratorでのConditionalバッチ正則化では、入力となるクラスベクトルcは各層で同じものを使用することで計算コストとメモリコストを抑えて学習時間の改善に成功した。
潜在ベクトルzに関してもGeneratorの入力層にのみ代入するのではなく、中間層にも代入し異なる解像度での特徴量にも反映されるようにしている。BigGANではこの潜在ベクトルzを解像度ごとに分割したあとでクラスベクトルcと結合させて各層に入力しているが、BigGAN-deepでは潜在ベクトルzをそのまま使用している。この設計によって学習時間と精度の改善に成功した。
3.2 Truncation Trick
本来ならGANの潜在ベクトルは任意の分布からサンプリングできるが、実際には入力となる潜在ベクトルの分布は一様分布あるいはガウス分布が採用されており、これが最適がどうかは後ほど検証する。
潜在ベクトルzは学習時にガウス分布からサンプリングし、推論時にはtruncatedな(切断)正規分布からサンプリングすることでISとFIDの向上が見られた(切断正規分布のイメージはコメントを参照)。
この打ち切る範囲を変化させることで生成される画像の質と多様性をトレードオフにすることが可能となる。
この閾値を変化させることでISとFIDに関してPR曲線のような質と多様性のトレードオフの関係を捉えることができている。クラス条件付きモデルにおいて、ISは生成画像の多様性を評価できないため閾値を低くすれば大きくスコアが改善されるが、FIDは多様性を評価するために閾値をゼロにしてしまうとスコアが多くく減少していることがわかる。
この手法の問題点は、学習時と推論時で事前分布が異なることであり、大きなモデルを学習させた際には以下のようにTruncated normalからサンプリングした潜在ベクトルzに対してうまく画像生成ができていない。
この問題に対処するには、Generatorを平滑化させてTruncated normalに対して適応できるようにすること、つまり潜在ベクトルzの全領域から良質な画像を生成できることが必要だと考え、Orthogonal正則化を適用した。
これは重み行列が直交行列になるように制限をかけることで、何度も行列積を行ったとしてもノルムが変化しない、つまり勾配消失や勾配爆発が抑えられると期待している。
しかしこの式は正則化効果が非常に強いため、制限のゆるい新たな正則化を提案した。これは重み行列の対角要素以外の要素が大きくならないように制限をかけることで、直接直交行列になるような制限よりも精度が向上すると考えられる。
4. どうやって有効だと検証した?
上記の手法によりモデルは改善されたが、途中で学習が崩壊してしまう現象も発生しておりEealy Stoppingを行う必要がある。まずはこの原因を調査する。
4.1 Analysis
4.1.1 Generator
まず以下では小さなモデルでは生じなかった不安定性が大きなモデルで発生した原因を見ていく。手法としてはObena(2018)で行われていたように、学習中の重み・勾配・損失を見ていき崩壊との関連性を探った。
結果としてはAlrnoldiの反復法で得られた第3位までの特異値が不安定性に関して優位な情報を有しており、以下の図のようにGeneratorの一部の層(たいてい最初の層)が不安定性を引き起こしていると考えられる。
この現象が不安定性の原因なのか単なる症状なのか検証するために、Generatorにスペクトルの過度な増大が発生しないような制限を加えた。
この手法をスペクトル正規化と同時に使用した場合でもそうでない場合でも、多少精度の改善や学習の安定化に貢献しているが、完全に不安定性を防いでいるわけではない。
4.1.2 Discriminator
Generatorと同様の分析を行った。Generatorと異なる殆どの層のスペクトルにはノイズが混じっており、不安定性が発生した場合も特異値は爆発しているわけではなく、多少値がジャンプしているのみである。
途中途中のスパイクはDiscriminatorが大きな勾配を受け取っていることを示すが、AppendixFによるとフロベニウスノルムは増大していないため、この現象は上位の特異値のみに発生している現象である。
この現象はGeneratorがDiscriminatorを大きく摂動させるような画像を生成してしまう敵対的学習手法によるものだと仮定し、Discriminatorのヤコビアンに対して陽に正則化を適用した。
結果としては正則化の効果をγ=10によることで学習は安定化しGeneratorとDiscriminator両方の重み行列のスペクトルは平滑化したが大きく精度は減少してしまい、正則化項をγ=1に設定しても精度が大きく減少し、同様に現象はOrthogonal正則化やDropout、L2正則にも見られた。
またDiscriminatorの損失値が0に近い場合でも不安定性が発生した場合には大きく損失が増大していることも判明した。これはDiscriminatorが学習データを完全に覚えてしまうことで汎化性能が減少してしまっているからと考えられる。
以上を踏まえると、現状では学習を安定化させながら精度を保つ手法は存在せず、学習を安定化させるためにはある程度精度を犠牲にするか、学習をEarly Stoppingなどで監視する必要がある。
4.2 Result
各手法を用いた場合の評価指標を比較した。順に(1)Trancation Trackを使用しない、(2)FIDを最小化した(多様性を犠牲)、(3)検証データに対するIS、(4)ISを最大化した場合を比較している。
5. 議論はあるか?