YutaroOgawa / pytorch_advanced

書籍「つくりながら学ぶ! PyTorchによる発展ディープラーニング」の実装コードを配置したリポジトリです
MIT License
837 stars 336 forks source link

EfficientGANの実装について #201

Open insilicomab opened 2 years ago

insilicomab commented 2 years ago

良書ありがとうございます。いろいろなことを勉強させていただいております。EfficientGANのサンプルコードも動かすことができました。そこで、サンプルコードをベースとしてRGB画像の異常検知を試みております。サンプルコードでは白黒なのでチャネルは1としていましたが、RGBですのでチャネルを3に変更してトライしたのですが、以下のようなエラーが出ました。

`--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) in () 1 # 学習・検証を実行する ----> 2 G_update, D_update, E_update = train_model(G, D, E, dataloader=train_dataloader, num_epochs=opt.num_epochs)

7 frames /usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias) 441 _pair(0), self.dilation, self.groups) 442 return F.conv2d(input, weight, bias, self.stride, --> 443 self.padding, self.dilation, self.groups) 444 445 def forward(self, input: Tensor) -> Tensor:

RuntimeError: Given groups=1, weight of size [32, 3, 3, 3], expected input[64, 1, 28, 28] to have 3 channels, but got 1 channels instead`

白黒画像からRGB画像ようにモデルを組み直す際はチャネルの変更、画像サイズなど変更すべきところがあると思うのですが、具体的にどこを修正すればよろしいでしょうか。

YutaroOgawa commented 2 years ago

@insilicomab さま

ありがとうございます。

EfficientGANのファイル https://github.com/YutaroOgawa/pytorch_advanced/blob/master/6_gan_anomaly_detection/6-4_EfficientGAN.ipynb

において、

Generatorの実装 の

        self.last = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=1,
                               kernel_size=4, stride=2, padding=1),
            nn.Tanh())
        # 注意:白黒画像なので出力チャネルは1つだけ

out_channelsを3に

Discriminatorの実装 の

        # 画像側の入力処理
        self.x_layer1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))
        # 注意:白黒画像なので入力チャネルは1つだけ

の nn.Conv2d(1, 64, kernel_size=4, を nn.Conv2d(3, 64, kernel_size=4,

にすることになると思います。もしかしたら、他にも修正する箇所があるかもしれませんが、ネットワーク的にはこのあたりかと思います。

あとは、本書のネットワークの形は単純なので、シンプルな白黒画像のMNISTレベルしか対応できず、EfficientGAN論文のカラー画像でのネットワーク構成を使用する方が良いかもしれません。

どうぞよろしくお願い致します。

insilicomab commented 2 years ago

大変ご丁寧な回答ありがとうございます。カラー画像でのネットワーク構成を使用してみます。

YutaroOgawa commented 2 years ago

@insilicomab さま

はい、ありがとうございます。 なかなか難しい取り組みだと思います。

ぜひネット上や他の書籍も参考にいただければと思います。

あっ、思い出しました。

書籍「PyTorchによる画像生成/画像変換のためのGANディープラーニング実装ハンドブック」 https://www.amazon.co.jp/dp/4798062294/

の第8章が異常検知(AnoGAN、EfficientGAN)です。 その 第4節がカラー画像でのEfficientGANの実装解説となっていて、果物の画像で異常検知しています。

こちらも参考になるかと思いました。

どうぞよろしくお願い致します。

insilicomab commented 2 years ago

わざわざご丁寧にありがとうございます。PyTorchによる画像生成/画像変換のためのGANディープラーニング実装ハンドブックのGitHubでのソースコードを見つけましたので、そちらを参考にしたいと思います。感謝しております。