keras-team / keras-applications

Reference implementations of popular deep learning models.

Unable to define the number of output classes in VGG16/19 when 'include_top'=False #53

Closed 2sang closed 5 years ago

2sang commented 5 years ago

Instantiating the VGG16/19 model with include_top=False always gives the output shape (None, 1, 1, 512), even though I provided classes=10 as well.

import keras
from keras.applications import vgg16

#  Use (-1, 32, 32, 3) CIFAR-10 images as input dataset, for example.
(train_x, train_y), (test_x, test_y) =\
      keras.datasets.cifar10.load_data()  

vgg = vgg16.VGG16(include_top=False,
                  weights=None,
                  input_shape=(32, 32, 3),
                  classes=10)
vgg.summary()

I expected (None, 1, 1, 10) as the output shape, but the model gives (None, 1, 1, 512):

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 32, 32, 3)         0
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 32, 32, 64)        1792
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 32, 32, 64)        36928
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 16, 16, 64)        0
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 16, 16, 128)       73856
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 16, 16, 128)       147584
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 8, 8, 128)         0
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 8, 8, 256)         295168
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 8, 8, 256)         590080
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 8, 8, 256)         590080
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 4, 4, 256)         0
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 4, 4, 512)         1180160
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 4, 4, 512)         2359808
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 4, 4, 512)         2359808
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 2, 2, 512)         0
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 2, 2, 512)         2359808
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 2, 2, 512)         2359808
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 2, 2, 512)         2359808
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 1, 1, 512)         0
=================================================================
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0
_________________________________________________________________

I'm not sure if this is the expected behaviour, but it seems the classes argument has no effect on the model when the include_top flag is off. https://github.com/keras-team/keras-applications/blob/4cef2452d27375e3a6c28ae89118174c72473ac2/keras_applications/vgg19.py#L187-L198

If that's okay, I'd like to make a PR with some corresponding test code. Thanks for the wonderful work, by the way; I'll look forward to your comment. :)

2sang commented 5 years ago

Sorry, classes should not be specified when include_top=False, as the comment says.
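For anyone landing here with the same question: since classes is ignored when include_top=False, the usual workaround is to attach your own classification head on top of the convolutional base. A minimal sketch (not from this thread; the head size of 256 units is an arbitrary illustrative choice):

```python
import keras
from keras.applications import vgg16
from keras.layers import Flatten, Dense
from keras.models import Model

# Convolutional base only; `classes` would be ignored here anyway.
base = vgg16.VGG16(include_top=False,
                   weights=None,
                   input_shape=(32, 32, 3))

# Custom head: flatten the (None, 1, 1, 512) features and map to 10 classes.
x = Flatten()(base.output)
x = Dense(256, activation='relu')(x)      # hypothetical hidden layer size
out = Dense(10, activation='softmax')(x)  # 10 classes for CIFAR-10

model = Model(inputs=base.input, outputs=out)
model.summary()
```

The resulting model's output shape is (None, 10), which is what the original snippet was after.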