raghakot / keras-resnet

Residual networks implementation using Keras-1.0 functional API

AssertionError running resnet() on Theano backend #6

Closed. thisisandreeeee closed this issue 7 years ago.

thisisandreeeee commented 8 years ago

I've been trying to use resnet for an image classification task, and am encountering the following assertion error:

Traceback (most recent call last):
  File "resnet.py", line 270, in <module>
    callbacks = [earlyStopping])
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/keras/engine/training.py", line 1026, in fit
    self._make_test_function()
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/keras/engine/training.py", line 695, in _make_test_function
    **self._function_kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/keras/backend/theano_backend.py", line 541, in function
    return Function(inputs, outputs, updates=updates, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/keras/backend/theano_backend.py", line 527, in __init__
    **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/theano/compile/function.py", line 320, in function
    output_keys=output_keys)
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/theano/compile/pfunc.py", line 479, in pfunc
    output_keys=output_keys)
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/theano/compile/function_module.py", line 1776, in orig_function
    output_keys=output_keys).create(
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/theano/compile/function_module.py", line 1456, in __init__
    optimizer_profile = optimizer(fgraph)
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/theano/gof/opt.py", line 101, in __call__
    return self.optimize(fgraph)
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/theano/gof/opt.py", line 89, in optimize
    ret = self.apply(fgraph, *args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/theano/gof/opt.py", line 230, in apply
    sub_prof = optimizer.optimize(fgraph)
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/theano/gof/opt.py", line 85, in optimize
    self.add_requirements(fgraph)
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/theano/tensor/opt.py", line 1438, in add_requirements
    fgraph.attach_feature(ShapeFeature())
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/theano/gof/fg.py", line 621, in attach_feature
    attach(self)
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/theano/tensor/opt.py", line 1262, in on_attach
    self.on_import(fgraph, node, reason='on_attach')
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/theano/tensor/opt.py", line 1304, in on_import
    assert d.dtype in theano.tensor.discrete_dtypes, (node, d.dtype)
AssertionError: (AbstractConv2d{border_mode='valid', subsample=(1.0, 1.0), filter_flip=True, imshp=(None, None, None, None), kshp=(256, 64, 1, 1)}(Subtensor{::, ::, :int64:, :int64:}.0, HostFromGpu.0), 'float64')
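To see how a float can end up in a shape check at all, here is a stdlib-only sketch (not Theano's actual code) of the usual "valid" convolution output-length formula, (input - filter) // stride + 1. If the stride is a float such as 1.0, the computed length becomes a float too. That is the kind of value the assertion d.dtype in theano.tensor.discrete_dtypes rejects; the traceback reports 'float64' next to subsample=(1.0, 1.0). The helper names and the 56-pixel input length are illustrative assumptions. The 1x1 filter matches kshp=(256, 64, 1, 1) in the traceback.

```python
def conv_output_length(input_length, filter_size, stride):
    """Output length of a 'valid' convolution along one axis.

    Uses the common formula (input - filter) // stride + 1. This is a
    sketch for illustration, not Theano's implementation.
    """
    return (input_length - filter_size) // stride + 1


def assert_discrete(length):
    # Hypothetical stand-in for Theano's check that shape values are integers.
    assert isinstance(length, int), ("conv output length", type(length).__name__)


int_len = conv_output_length(56, 1, 1)      # 56, an int
float_len = conv_output_length(56, 1, 1.0)  # 56.0: the float stride leaks into the shape

assert_discrete(int_len)  # passes
try:
    assert_discrete(float_len)
except AssertionError as err:
    print("rejected:", err)  # rejected: ('conv output length', 'float')
```

Even floor division (//) returns a float when one operand is a float. So the fix has to make the stride itself an int.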

I've changed the input shape and output dimension of the resnet() function and left the other helper functions untouched. This is how I am using resnet():

from keras.preprocessing.image import ImageDataGenerator

# sgd, earlyStopping, batch_size, nb_epoch and the X_*/Y_* arrays are
# defined elsewhere in the script (not shown here)
model = resnet()
model.compile(loss='categorical_crossentropy',
              optimizer=sgd,  # custom optimizer
              metrics=['accuracy'])

# Only small random rotations are used for augmentation
datagen = ImageDataGenerator(
    featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    zca_whitening=False,
    rotation_range=5,
    width_shift_range=0,
    height_shift_range=0,
    horizontal_flip=False,
    vertical_flip=False)

datagen.fit(X_train)
model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size),
                    samples_per_epoch=X_train.shape[0],
                    nb_epoch=nb_epoch,
                    validation_data=(X_val, Y_val),
                    callbacks=[earlyStopping])

The error is thrown at the model.fit_generator call. It may also be useful to know that I've tested the same code on much simpler conv nets.

Any thoughts on how I should deal with this error?

entron commented 8 years ago

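I'm getting the same error. It comes from a difference between Python 2 and Python 3. In Python 3, / is true division and returns a float even when both operands are ints, so the strides computed in _shortcut become floats (note subsample=(1.0, 1.0) in the traceback). I solved it by changing the corresponding two lines in _shortcut to use floor division (//):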

    # floor division keeps the strides as ints on both Python 2 and 3
    stride_width = input._keras_shape[2] // residual._keras_shape[2]
    stride_height = input._keras_shape[3] // residual._keras_shape[3]

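As a standalone illustration of the fix, here is a minimal sketch that computes the shortcut strides from plain shape tuples instead of Keras tensors. The helper name shortcut_strides and the example shapes are hypothetical. The shapes assume Keras 1 "th" dim ordering (batch, channels, rows, cols), which is why indices 2 and 3 are used.

```python
def shortcut_strides(input_shape, residual_shape):
    """Hypothetical helper: strides for the 1x1 projection conv in a ResNet
    shortcut, given "th"-ordered shapes (batch, channels, rows, cols)."""
    # Floor division (//) returns ints on Python 2 and 3 alike; plain /
    # would give floats such as 1.0 on Python 3.
    stride_width = input_shape[2] // residual_shape[2]
    stride_height = input_shape[3] // residual_shape[3]
    return stride_width, stride_height


# Illustrative shapes (assumed, not taken from the issue).
same_size = shortcut_strides((None, 64, 56, 56), (None, 256, 56, 56))
downsample = shortcut_strides((None, 64, 56, 56), (None, 128, 28, 28))
print(same_size, downsample)  # (1, 1) (2, 2)
print(56 / 56)                # 1.0 on Python 3, the float the old code produced
```

Since // floors silently, a custom input shape whose spatial size does not divide evenly by the residual's would give a truncated stride (for example 57 // 28 == 2). If you change the input shape, it may be worth checking this.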
raghakot commented 7 years ago

Thanks. Fixed.