PaddlePaddle / PaddleX

All-in-One Development Tool based on PaddlePaddle (the PaddlePaddle one-stop, full-workflow development tool)
Apache License 2.0

Adding a backbone network #1026

Open MichaelZhero opened 3 years ago

MichaelZhero commented 3 years ago

Issue type: I want to add my own VGG16 backbone. Is there a tutorial on how to define one?

PaddleX version
2.0.0

Problem description

I want to add paddlex/cv/nets/vgg16.py. For now I load the network directly with paddle.vision.vgg16(), but training fails with this error:

    Traceback (most recent call last):
      File "E:/github/PaddleX-develop-0601/PaddleX-develop/tutorials/train/object_detection/faster_rcnn_vgg16.py", line 54
        model.train(
      File "E:\github\PaddleX-develop-0601\PaddleX-develop\paddlex\cv\models\faster_rcnn.py", line 357, in train
        self.build_program()
      File "E:\github\PaddleX-develop-0601\PaddleX-develop\paddlex\cv\models\base.py", line 105, in build_program
        self.train_inputs, self.train_outputs = self.build_net(mode='train')
      File "E:\github\PaddleX-develop-0601\PaddleX-develop\paddlex\cv\models\faster_rcnn.py", line 230, in build_net
        model_out = model.build_net(inputs)
      File "E:\github\PaddleX-develop-0601\PaddleX-develop\paddlex\cv\nets\detection\faster_rcnn.py", line 223, in build_net
        body_feat_names = list(body_feats.keys())
    AttributeError: 'Variable' object has no attribute 'keys'
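The failure can be reproduced without Paddle. In this minimal sketch, `FakeVariable` is a hypothetical stand-in for the single tensor a VGG16 backbone returns, and `detector_body_feat_names` is a hypothetical helper that mimics the line in faster_rcnn.py that calls `.keys()` on the backbone output.

```python
from collections import OrderedDict


class FakeVariable:
    """Hypothetical stand-in for a single Paddle Variable (one feature map)."""

    def __init__(self, name):
        self.name = name


def detector_body_feat_names(body_feats):
    # Mirrors `body_feat_names = list(body_feats.keys())` from the traceback:
    # the detector assumes the backbone returned a dict-like object.
    return list(body_feats.keys())


# A ResNet-style backbone returns an OrderedDict of named feature maps: this works.
resnet_feats = OrderedDict([("res4_sum", FakeVariable("c4")),
                            ("res5_sum", FakeVariable("c5"))])
print(detector_body_feat_names(resnet_feats))  # ['res4_sum', 'res5_sum']

# A VGG16-style backbone returns one feature map: this fails the same way.
try:
    detector_body_feat_names(FakeVariable("vgg_out"))
except AttributeError as err:
    print(err)  # 'FakeVariable' object has no attribute 'keys'
```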

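After looking into it: the keys are defined in each backbone as `res{}_sum`. The other backbones return multi-scale feature maps, whereas VGG16 only has a single feature map.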

```python
from collections import OrderedDict

from paddle.fluid.framework import Variable


def __call__(self, input):
    assert isinstance(input, Variable)
    if isinstance(self.feature_maps, (list, tuple)):
        assert not (set(self.feature_maps) - set([2, 3, 4, 5])), \
            "feature maps {} not in [2, 3, 4, 5]".format(self.feature_maps)
    res_endpoints = []
    res = input
    feature_maps = self.feature_maps
    out = self.net(input)
    # Returns a single Variable here; a detector that calls .keys() on the
    # result (as in the traceback above) cannot handle this case.
    if self.num_classes or self.feature_maps == "stage4":
        return out
    for i in feature_maps:
        # self.end_points is assumed to hold stage outputs 2..5 at indices 0..3
        res = self.end_points[i - 2]
        if i in self.feature_maps:
            res_endpoints.append(res)
        if self.freeze_at >= i:
            res.stop_gradient = True
    return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat)
                        for idx, feat in enumerate(res_endpoints)])
```
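To see which keys the `__call__` method produces, its selection loop can be replayed without Paddle. This is a simplified sketch: `select_endpoints` is a hypothetical helper, plain strings stand in for feature maps, and recording a stage in `frozen` stands in for setting `stop_gradient = True`. A VGG16 loaded as a single-output network through paddle.vision.vgg16() has no per-stage `end_points` list to feed into this loop, which is why it cannot produce the `res{}_sum` keys as-is.

```python
from collections import OrderedDict


def select_endpoints(end_points, feature_maps, freeze_at):
    """Hypothetical, framework-free replay of the loop in __call__.

    end_points[k] is assumed to hold the output of stage k + 2,
    matching the `self.end_points[i - 2]` indexing.
    """
    assert not (set(feature_maps) - {2, 3, 4, 5}), feature_maps
    selected = OrderedDict()
    frozen = []
    for i in feature_maps:
        feat = end_points[i - 2]
        selected["res{}_sum".format(i)] = feat
        if freeze_at >= i:
            # The real code sets feat.stop_gradient = True at this point.
            frozen.append(i)
    return selected, frozen


stages = ["c2", "c3", "c4", "c5"]  # placeholder strings standing in for tensors
feats, frozen = select_endpoints(stages, [3, 4, 5], freeze_at=3)
print(list(feats.keys()))  # ['res3_sum', 'res4_sum', 'res5_sum']
print(frozen)  # [3]
```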

FlyingQianMM commented 3 years ago

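To use a custom backbone, you need to modify the PaddleX source code yourself, following the structure of the existing PaddleX code. If you want to use paddle.vision.vgg16(), you can refer to its source implementation and port that code into the PaddleX source.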
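One possible way to port a single-output network is to wrap its output in the dict format the detector expects. The sketch below is framework-agnostic and hypothetical: `SingleOutputBackboneAdapter`, `fake_vgg16`, and the choice of the `res5_sum` key are illustrations, not PaddleX APIs. The downstream heads may still need their settings (for example the feature stride used for RoI pooling) adjusted to match VGG16's output, and that is not checked here.

```python
from collections import OrderedDict


class SingleOutputBackboneAdapter:
    """Hypothetical wrapper: makes a one-output backbone look like a
    ResNet-style PaddleX backbone by returning an OrderedDict of features."""

    def __init__(self, net, stage=5):
        self.net = net
        # Key follows the res{}_sum convention so code that calls .keys()
        # or looks up features by name keeps working.
        self.key = "res{}_sum".format(stage)

    def __call__(self, inputs):
        out = self.net(inputs)
        if isinstance(out, dict):
            return out  # already in the expected format
        return OrderedDict([(self.key, out)])


def fake_vgg16(x):
    # Hypothetical stand-in for a VGG16 feature extractor; returns one "feature map".
    return "features({})".format(x)


backbone = SingleOutputBackboneAdapter(fake_vgg16)
body_feats = backbone("image")
print(list(body_feats.keys()))  # ['res5_sum']
```

Returning an existing dict unchanged lets the same wrapper sit in front of backbones that already follow the `res{}_sum` convention.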