YutaroOgawa / pytorch_advanced

書籍「つくりながら学ぶ! PyTorchによる発展ディープラーニング」の実装コードを配置したリポジトリです
MIT License
835 stars 337 forks source link

第3章 PSPNet p.146のstrideについて #48

Open atsushi-tmdu opened 4 years ago

atsushi-tmdu commented 4 years ago

Resnetの部分について質問です。原著論文ではfeature_dilated_res_1, feature_dilated_res_2のstrideは2であり、https://github.com/hszhao/semseg/blob/master/model/resnet.pyでもstride 2で作成されていましたが、あえてここのstrideを変更された理由がありましたらご教授いただけると幸いです。あるいは自分、何かを勘違いしているのでしょうか...

YutaroOgawa commented 4 years ago

@atsushi-tmdu さま

非常に重要なご質問をありがとうございます。

また、丁寧に丁寧に本書や他論文などを読み込んでのご質問を賜り、 誠にありがとうございます。

【1】hszhao氏のリポジトリsemsegのmodelでは、 https://github.com/hszhao/semseg/blob/master/model/pspnet.py#L52 および、L57を見ると、

self.layer3とself.layer4(私のfeature_dilated_res_1,2に相当)の strideは1に上書き設定されております。

【2】chainerのpspnetにおいても https://github.com/chainer/chainercv/blob/v0.13.1/chainercv/experimental/links/model/pspnet/pspnet.py#L103 およびL106を見ると、

self.res4とself.res5(私のfeature_dilated_res_1,2に相当) もstrideは1になっております。

【3】hszhao氏のリポジトリsemsegのmodelでは、 最初に、 https://github.com/hszhao/semseg/blob/master/model/pspnet.py#L48 にて、

self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

として、resnetのlayer3,4を与えています。

これらのresnetのlayersは、torchvisionの公式モデルであり、 https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L149 および、L151の通り、 strideが2となっています(通常のResNet)。

hszhao氏は、PSPNetで使用する際に、この通常ResNetをロードしてから

https://github.com/hszhao/semseg/blob/master/model/pspnet.py#L52 を再掲するように、

self.layer3とself.layer4のstrideなどを上書きし、loadした通常ResNetから構成を変更しています。

その変更後が本書のモデル構成であり、stride=1となっていることに対応します。

【4】最後に念のため、 https://github.com/hszhao/PSPNet のpspnet50_ADE20K.caffemodelを、単純にPyTorchモデルに変更した、 本書p.134記載の学習済みモデル(pspnet50_ADE20K.pth)を、 feature_dilated_res_1, feature_dilated_res_2をstride=2に変更したモデルに対して、ロードしてみたのですが、元のpspnet50_ADE20K.caffemodeモデルがstride=1なので、stride=1に更新変更されてしまいます。

やはり、stride=1でhszhao氏らは作成していると、彼らの学習済みモデルからも推察されます。

【結論】 @atsushi-tmdu さまのおっしゃる通り、通常のResNetにおいては、 feature_dilated_res_1, feature_dilated_res_2のstrideは2ですが、 PSPNetで使用する際にはstride=1に変更して実装されている、と私は考えております。 そのため本書もそう構成されております。