Open mashimashica opened 2 years ago
現在モデルは作り終えた(ほぼtorchvisionにあるのを持ってきただけ)が、入力と出力の次元は要相談。
おそらく入力が全体の観測(O_t)で出力が1-hot representationであるmessage(m_t)になっているはずであるが確認したい。
VGG16のモデルの中身を出力させた。すると結果は以下の通りになった。 これは前にも載せたこの図と同じようになっていることがわかる。
VGG( (features): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): ReLU(inplace=True) (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (6): ReLU(inplace=True) (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (8): ReLU(inplace=True) (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (11): ReLU(inplace=True) (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (13): ReLU(inplace=True) (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (15): ReLU(inplace=True) (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (18): ReLU(inplace=True) (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (20): ReLU(inplace=True) (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (22): ReLU(inplace=True) (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (25): ReLU(inplace=True) (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (27): ReLU(inplace=True) (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (29): ReLU(inplace=True) (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (avgpool): AdaptiveAvgPool2d(output_size=(7, 7)) (classifier): Sequential( (0): Linear(in_features=25088, out_features=4096, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.5, inplace=False) (3): Linear(in_features=4096, out_features=4096, bias=True) (4): ReLU(inplace=True) (5): Dropout(p=0.5, inplace=False) (6): Linear(in_features=4096, out_features=1000, bias=True) ) )
入力画像は,121*121*3です.128*128*3に変形しても良いかもしれません
遅くなってすみません.以下の訓練コードでゲーム環境上でVAEが訓練できることを確認できたので取り急ぎ共有します. 後になって,また修正するつもりです. https://github.com/envzhu/WM2021_LWM/tree/vae
また,以下のcolabノートでGoogle Drive上にソースコードをアップロードすれば,実行できます. ローカルで実行する場合も,上手く依存パッケージをインストールできない場合は以下を参考にして下さい https://colab.research.google.com/drive/175AJFhfgCMBdDPDS7BwktwGBAXwveflF?usp=sharing
訓練状況の確認のために,一定周期ごとにテスト画像がvae_test.pngに出力されます.(https://github.com/mashimashica/WM2021_LWM/issues/1#issuecomment-1094235842
Gumbel Softmaxを使うと,One-hotでも勾配計算ができるらしいです. 論文中のAlgorithm 2に,しれっとGumbelとありますが,このことだと思います.
pytorchの場合は以下を参照.One-hot化するには,hard=Trueとする必要あり(?)
こちらでリファクタリングしたものを共有します https://github.com/envzhu/WM2021_LWM/tree/speaker
@Yoshida0404
Gumbel Softmaxの実装箇所でミスがありました.以下のように修正して下さい.
- ret = y_hard.detach() - y_soft.detach() + y_soft
+ y_hard = y_hard.detach() - y_soft.detach() + y_soft
return y_soft, y_hard
勾配が上手く伝わってないかったのは,これが原因なのだと思います.
CNNを使うということであるが、ここでは既に存在しているモデルを使えば効率良く良い成果を出すことができ、またおそらく元論文の著者も実際に何を使ったかの記載は無いが既存のモデルを使っていると考えられる。
ここで言うCNN関連の既存のモデルといえば、最初期のものでは"LeNet"(1989)や"ImageNet"(2009), それより先のモデルでは2010-2017年に開催されたImageNetを用いた画像認識競技会ImageNet Large Scale Visual Recognition Challenge(ILSVRC)の優勝モデルなど(AlexNet, VGG, ResNetなど)が挙げられる。(これより新しいモデルもあり高性能かもしれないが、わかりやすさや元論文との整合性を捨ててまで使うことはないと考えた)
元論文を見てみると、「CNN, Linear layer, softmaxを使う」とあり、これに過不足無く当てはまっているのはVGGである。(LeNetではsigmoidを使ったりしていて時代遅れ感があり、逆にResNetやそれより先ではSkip Connectionなどの記載をするはずだ、という考えから) それゆえ今回はまず一旦VGGシリーズの中でも一番良く使われているVGG16をBatch Normalization無しで使うことにする。