Closed curogihu closed 4 years ago
efficientnet.pyのコードのb0~b7を一つにまとめてgetattrできないかな?
下記の方向で進めます
if __name__ == '__main__':
model_name = 'efficientnetb3'
model_name_converted = model_name.replace('efficientnetb', 'EfficientNetB')
nb_classes = 10
height = 600
width = 800
with tf.device("/cpu:0"):
base_model = getattr(en, model_name_converted)(
include_top=False,
input_shape=(height, width, 3),
weights='imagenet'
)
x = base_model.output
x = GlobalAveragePooling2D()(x)
predictions = Dense(nb_classes, activation='softmax')(x)
model = Model(base_model.input, predictions)
print(model.summary())
efficientnetのコード集約含めて、一通り手直し済み。
EfficientNetB0からEfficientNetB7まで追加。 動作テストとしてEfficientNetモデルB0 - B7をexampleフォルダで使用し、正常動作したことを確認。