shimopino / papers-challenge

Paper Reading List I have already read
30 stars 2 forks source link

Neural Discrete Representation Learning #105

Open shimopino opened 4 years ago

shimopino commented 4 years ago

論文へのリンク

[arXiv:1711.00937] Neural Discrete Representation Learning

著者・所属機関

Aaron van den Oord, Oriol Vinyals, Koray Kavukcuoglu

投稿日時(YYYY-MM-DD)

2017-11-02

1. どんなもの?

VQ-VAE

通常のVAEと異なり、事前分布をPixelCNNなどの自己回帰型のモデルで学習させておき、Encodeした特徴マップを離散ベクトルに変換することで、課題となっていた事後分布の崩壊を防ぎ、かつ高品質な画像を生成することに成功した。

2. 先行研究と比べてどこがすごいの?

本研究では従来は連続変数として表現されていた潜在ベクトルを離散変数として表現することを目的にしており、離散値のほうが自然言語や画像を表現するのに適していると考えている。

本研究では潜在変数を離散化することで、潜在変数が大きく変動していまう現象を防ぎ、これによりDecoderが強力になりすぎてしまい潜在変数を無視してしまう事後分布の崩壊を防ぐことになる。

また離散化させたVAEであるVQ-VAEが、連続変数を使用した対数尤度型のモデルと同程度の能力を有していることを示した。

3. 技術や手法の"キモ"はどこにある?

VAEでは直交共分散の正規分布を仮定しており"reparametrization trick"を使用することで勾配の計算を可能にしている。

VQ-VAEでは事後分布も事前分布もカテゴリ分布を仮定しており、サンプルはこの分布からのインデックスとして取り出す。

3.1 Discrete Latent variables

本モデルでは潜在空間をK個のD次元の離散ベクトルで表現している。入力されたデータxをEncoderを使用して特徴マップを生成し、潜在空間上での離散化されたベクトルとの距離を計算し最も近い離散ベクトルに変換する。

最近傍のベクトルで変換するために得られる分布は決定論的な分布となる。つまり潜在変数zに対して一様分布からサンプリングすると、KL Divergenceの定数が得られ、これはlogKに等しくなる。

image

Decoderへの入力はこの特徴マップを離散化ベクトルに変換した新たな潜在変数になる。

image

3.2 Learning

Decoder側からEncoder側への勾配を計算することができないために、Decoder側の入力への勾配をEncoder側の出力への勾配に近似している。

image

損失関数は3つの要素で構成される。1つ目は再構成損失であるが途中の勾配をスキップするため、離散ベクトルは勾配を受け取ることがないため最適化させることができない。

そこで離散ベクトルを最適化させるためにVQを学習させるためのアルゴリズムを導入する。これは離散ベクトルとEncoder側で出力される特徴マップとのL2損失を計算することで、離散ベクトルが可能な限り特徴マップに近づくように学習を行うことが可能となる。

なお離散ベクトルには、Encoder側の特徴マップの移動平均の関数として更新を行う。またEncoder側も離散ベクトルの変化に合うように3つ目の項を追加している。

image

spは勾配を伝播させないstop-gradientを意味している。これでそれぞれの項に対しDecoderは1つ目の項で最適化を行い、離散ベクトルは2つ目の項で最適化を行い、Encoderは3つ目の項で最適化を行っていく。

4. どうやって有効だと検証した?

5. 議論はあるか?

shimopino commented 4 years ago

離散ベクトルの更新には移動平均(EMA)を使用することもできる。

Encoderから出力された特徴マップと対応する最近傍の離散ベクトル間での損失関数は以下のように記述できる。

image

最近傍の計算にはKMeansを使用することもできるが、ミニバッチに対して直接パラメータの更新を行うことができないのでEMAを使用してオンラインで更新できるようにしている。

image

shimopino commented 4 years ago

特徴マップでは以下のようにチャンネル方向のベクトル対して量子化の処理を行う。

GAN-2

# D=64, K=10
emb_dim, num_emb = 64, 10

# inputs: [B, C=D, H, W] -> [B, H, W, C=D]
inputs = inputs.permute(0, 2, 3, 1).contiguous()

# flatten feature map: [B, H, W, D] -> [BxHxW=N, D]
flatten = inputs.view(-1, emb_dim)

# embedding weight: [D, K]
embeddings = torch.randn(emb_dim, num_emb)
shimopino commented 4 years ago

次に特徴マップのN個のD次元ベクトルと、参照する辞書に格納されているK個のD次元ベクトルとの距離を計算する。

GAN-3

# flatten:   [N, D]
# embedding: [D, K]
distance = (
    flatten.pow(2).sum(dim=1, keepdim=True)
    -2 * flatten @ embeddings
    + embeddings.pow(2).sum(dim=0, keepdim=True)
)
shimopino commented 4 years ago

これでN個のベクトルとK個のベクトルの全パターンでの距離が格納されているNxK次元の行列を計算することができる。あとはK方向に距離が最小となるIndexを計算し、Embeddingを使用してIndexを辞書となるベクトルに変換すればいい。

GAN-4

# embeddings_idx: [N, K] -> [N, ]
embeddings_idx = torch.argmin(distance, dim=1)

# reshape embeddings_idx: [N, ] --> [B, H, W, ]
# this should be doen before applying Embedding Func
embeddings_idx = embeddings_idx.view(*inputs.size()[:-1])

# quantize: [B, H, W, ]x[K, D] -> [B, H, W, D]
quantize = F.embedding(embeddings_idx, embeddings.transpose(0, 1))

# reshape to the same shape of inputs
# quantize:  [B, H, W, D]-> [B, H, W, D=C]
quantize = quantize.view(*inputs.size())
shimopino commented 4 years ago

これでEncoderから出力された特徴マップと、量子化した後の特徴マップが得られるので、この特徴マップ間の距離を計算し最小化できるように学習を進めていくことが可能となる。

GAN-5

e_latent_loss = F.mse_loss(quantize.detach(), inputs)
q_latent_loss = F.mse_loss(inputs.detach(), quantize)

loss = q_latent_loss + commitment * e_latent_loss

# Decoder側の量子化された特徴マップに対する勾配をそのままEncoder側の入力に渡す。
# 逆伝播用の計算グラフを構築しないdetach()を使用する。
quantize = inputs + (quantize - inputs).detach()
shimopino commented 4 years ago

辞書の各ベクトルに関しては、以下のように参照元の特徴マップ内のベクトルの平均として計算することができる。しかし、ミニバッチによる学習では限られたデータセットに対してのみ参照を計算することになってしまう。

GAN-6

そこで過去の計算結果にも参照できるように指数移動平均を使用する。おおよそβ=0.99に設定することで過去100回程度の値を参照する形に変更できる。

shimopino commented 4 years ago

まず辞書の各ベクトルの参照回数の指数移動平均を計算する。OneHotベクトルを利用することで簡単に対象のミニバッチ内での参照回数を計算できる。ゼロ頻度のベクトルが存在しても問題ないようにラプラス平滑化を行っている。

GAN-7

# embeddings_onehot: [N, ] --> [N, K] 
embeddings_onehot = F.onehot(embeddings_idx, num_classes=num_emb)
# reference_counts: [N, K] --> [K, ]
ref_counts = torch.sum(embeddings_onehot, dim=0)

# EMA for reference counts
ema_ref_counts = beta * ema_ref_counts + (1-beta)*ref_counts

# total reference counts
n = ema_ref_counts.sum()

# laplace smoothing
ema_ref_counts = ((ema_ref_counts + eps)
                  /(n + ema_ref_counts*eps)) * n
shimopino commented 4 years ago

最後にOneHotベクトルを使用して、N個のベクトルの中から、特定のIndexに対応する量子化前のベクトルの総和を計算する。その後に計算済みの参照回数を使用して平均値として算出する。

GAN-8

# dw: [D, N] x [N, K] --> [D, K]
dw = flatten.transpose(0, 1) @ embeddings_onehot

# ema for embeddings: [D, K]
ema_embeddings = beta * ema_embeddings + (1-beta)*dw

# normalize by reference counts
# [D, K] / [1, K]
embeddings = ema_embeddings / ema_ref_counts.unsqueeze(0)
shimopino commented 4 years ago

あとは性能評価のためにPerplexityを計算すればいい。

# avg_probs: [N, K] --> [K, ]
avg_probs = torch.mean(embeddings_onehot, dim=0)

# 2^{entropy}
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))