rcmalli / keras-squeezenet

SqueezeNet implementation with Keras Framework
MIT License

Can't retrain the model with new classes? #4

Open. ghost opened this issue 7 years ago

ghost commented 7 years ago

I am trying to retrain the model with new classes. Because there are no proper docs, I had to make many assumptions to get retraining working. I want to use it to classify two objects, so I changed nb_classes to 2 and tried to retrain it, but it gives this error:

ValueError: Cannot reshape input of shape (1000,) to shape [1 2 1 1]
Apply node that caused the error: Reshape{4}(conv10_b, TensorConstant{[1 2 1 1]})

This is the code:

# default model with the Theano weights ('th' dim ordering)
model = get_squeezenet(2, dim_ordering='th')
.
.
y=np.array([[1,0]])
model.fit(im,y, batch_size=1, nb_epoch=1)
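For a two-class softmax output, the labels passed to model.fit are one-hot rows of length 2, one row per image; y=np.array([[1,0]]) above encodes a single image of class 0. A minimal numpy sketch for building such labels from integer class ids (the helper name one_hot is hypothetical; Keras's to_categorical utility does the same job):

```python
import numpy as np


def one_hot(class_ids, num_classes):
    """Turn integer class ids into one-hot rows of shape (len(class_ids), num_classes)."""
    class_ids = np.asarray(class_ids, dtype=int)
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(num_classes, dtype="float32")[class_ids]


# Class 0 -> [1, 0], class 1 -> [0, 1]; same layout as y=np.array([[1,0]]).
y = one_hot([0, 1, 1], num_classes=2)
print(y.shape)  # (3, 2)
```

The number of label rows must match the number of images in the batch passed to model.fit.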
davinnovation commented 7 years ago

model = SqueezeNet(weights=None, classes=2) should work.
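A plausible reading of the first error: according to the traceback, conv10_b (the bias of the final conv10 layer) in the pretrained weights holds 1000 values, one per ImageNet class, while a model built for 2 classes expects 2, so loading the weights fails on the reshape. With weights=None nothing is loaded and the whole network, including the 2-class conv10, is trained from scratch. A minimal numpy sketch reproducing the shape mismatch (the array is a stand-in, not the real weights; fits_layer is a hypothetical helper):

```python
import numpy as np

# Stand-in for the pretrained conv10 bias: one value per ImageNet class.
pretrained_conv10_bias = np.zeros(1000, dtype="float32")


def fits_layer(bias, num_classes):
    """Return True if the bias can be reshaped to (1, num_classes, 1, 1), the shape in the traceback."""
    try:
        bias.reshape(1, num_classes, 1, 1)
        return True
    except ValueError:
        return False


print(fits_layer(pretrained_conv10_bias, 1000))  # True
print(fits_layer(pretrained_conv10_bias, 2))     # False: 1000 values cannot fill 2 slots
```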

ghost commented 7 years ago

The new update solved that problem, but now I have another one:

import numpy as np
from keras_squeezenet import SqueezeNet
from keras.applications.imagenet_utils import preprocess_input, decode_predictions
from keras.preprocessing import image

model = SqueezeNet(weights=None, classes=2)

img = image.load_img('dog.jpg', target_size=(200, 200))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

preds = model.predict(x)
print('Predicted:', decode_predictions(preds))
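Separately from the size error, decode_predictions in keras.applications.imagenet_utils maps scores to the 1000 ImageNet labels and rejects predictions whose second dimension is not 1000, so it cannot decode the output of a classes=2 model. A minimal sketch that decodes a 2-class output against your own label list (CLASS_NAMES and decode_two_class are hypothetical names, and the label order is an assumption that must match the training labels):

```python
import numpy as np

# Hypothetical label list; index i must match column i of the training labels.
CLASS_NAMES = ["cat", "dog"]


def decode_two_class(preds, class_names=CLASS_NAMES):
    """Return (label, score) for the top-scoring class of each row in a (batch, num_classes) array."""
    preds = np.asarray(preds)
    if preds.ndim != 2 or preds.shape[1] != len(class_names):
        raise ValueError("expected shape (batch, %d), got %s" % (len(class_names), preds.shape))
    top = preds.argmax(axis=1)
    return [(class_names[i], float(preds[row, i])) for row, i in enumerate(top)]


print(decode_two_class([[0.2, 0.8], [0.9, 0.1]]))
```

In the snippet above, this would replace decode_predictions(preds) with decode_two_class(preds).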

It's an image size error:

ValueError: Error when checking : expected input_1 to have shape (None, 227, 227, 3) but got array with shape (1, 200, 200, 3)

I know you are not the author, but thank you for the reply.

davinnovation commented 7 years ago

@potholiday there are two solutions:

  1. Resize the image to 227x227.
  2. Modify line 58 of the keras-squeezenet code, changing
     input_shape = _obtain_input_shape(input_shape, default_size=227, min_size=48, data_format=K.image_data_format(), include_top=True)
     to
     input_shape = _obtain_input_shape(input_shape, default_size=200, min_size=48, data_format=K.image_data_format(), include_top=True)

I think the first solution is easier.
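With the Keras preprocessing utilities already in the reporter's snippet, the first solution amounts to calling image.load_img('dog.jpg', target_size=(227, 227)), which gives x the shape (1, 227, 227, 3) the model expects. If the image is already a NumPy array, a minimal nearest-neighbour resize sketch (resize_nearest is a hypothetical helper; a library resize with proper interpolation will usually give better image quality):

```python
import numpy as np


def resize_nearest(batch, height, width):
    """Nearest-neighbour resize of a (batch, H, W, C) array to (batch, height, width, C)."""
    _, in_h, in_w, _ = batch.shape
    # Map each output row/column to a source row/column by scaling its index down.
    rows = (np.arange(height) * in_h / height).astype(int)
    cols = (np.arange(width) * in_w / width).astype(int)
    # Broadcasting rows (height, 1) against cols (1, width) picks a (height, width) grid of pixels.
    return batch[:, rows[:, None], cols[None, :], :]


x = np.random.rand(1, 200, 200, 3).astype("float32")  # same shape as in the error message
x227 = resize_nearest(x, 227, 227)
print(x227.shape)  # (1, 227, 227, 3)
```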