keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Reshape error when one of the target dimensions is -1 #4835

Closed (jrao1 closed this issue 7 years ago)

jrao1 commented 7 years ago

I realize this feature is not in the documentation, but the code is structured to support it. However, it doesn't quite work. Test script:

from keras.layers import core
from keras.utils.test_utils import layer_test

layer_test(core.Reshape,
           kwargs={'target_shape': (-1, 1)},
           input_shape=(3, 2, 4))

Result:

Traceback (most recent call last):
  File "test.py", line 6, in <module>
    input_shape=(3, 2, 4))
  File "E:\Projects\AI\DeepLearning\keras\keras\utils\test_utils.py", line 77, in layer_test
    y = layer(x)
  File "E:\Projects\AI\DeepLearning\keras\keras\engine\topology.py", line 569, in __call__
    self.add_inbound_node(inbound_layers, node_indices, tensor_indices)
  File "E:\Projects\AI\DeepLearning\keras\keras\engine\topology.py", line 632, in add_inbound_node
    Node.create_node(self, inbound_layers, node_indices, tensor_indices)
  File "E:\Projects\AI\DeepLearning\keras\keras\engine\topology.py", line 164, in create_node
    output_tensors = to_list(outbound_layer.call(input_tensors[0], mask=input_masks[0]))
  File "E:\Projects\AI\DeepLearning\keras\keras\layers\core.py", line 354, in call
    return K.reshape(x, (-1,) + target_shape)
  File "E:\Projects\AI\DeepLearning\keras\keras\backend\theano_backend.py", line 565, in reshape
    return T.reshape(x, shape)
  File "d:\programming\theano\theano\theano\tensor\basic.py", line 4722, in reshape
    newshape = as_tensor_variable(newshape)
  File "d:\programming\theano\theano\theano\tensor\basic.py", line 212, in as_tensor_variable
    raise AsTensorError("Cannot convert %s to TensorType" % str_x, type(x))
theano.tensor.var.AsTensorError: ('Cannot convert (-1, None, 8, 1) to TensorType', <type 'tuple'>)
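As a rough illustration of why the call fails: the Theano backend hands the shape to `as_tensor_variable`, and the tuple it received, `(-1, None, 8, 1)`, contains `None` (the unknown batch size) right after the leading `-1`. The sketch below is a hypothetical pre-check, not part of Keras or Theano, that flags such shape arguments. It assumes a backend reshape expects plain integers with at most one `-1` placeholder.

```python
def check_reshape_arg(shape):
    """Hypothetical pre-check (not a Keras/Theano API): list the reasons a
    shape tuple would be unsuitable for a backend reshape call."""
    problems = []
    # Assumption: every entry must be a concrete integer; None is not allowed.
    if any(not isinstance(d, int) for d in shape):
        problems.append("non-integer entry (e.g. None)")
    # Assumption: at most one dimension may be left for the backend to infer.
    if sum(1 for d in shape if d == -1) > 1:
        problems.append("more than one -1")
    return problems


print(check_reshape_arg((-1, None, 8, 1)))  # ['non-integer entry (e.g. None)']
print(check_reshape_arg((-1, 8, 1)))        # []
```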

It seems to me that line 353 of core.py should be changed to:

target_shape = self.get_output_shape_for(input_shape)[1:]
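A minimal sketch of the shape arithmetic behind the proposed fix, in plain Python. `resolve_target_shape` and `output_shape_for` are hypothetical helpers standing in for the layer's output-shape logic, not the actual Keras code, and the batch shape `(None, 2, 4)` assumes `layer_test` builds its input with an unknown batch size, as the `None` in the error message suggests. Prepending `-1` to the full output shape reproduces the rejected tuple, while slicing off the batch entry with `[1:]` first yields a plain integer shape.

```python
from functools import reduce
import operator


def resolve_target_shape(input_shape, target_shape):
    """Hypothetical helper: replace a single -1 in target_shape with the size
    implied by the per-sample input_shape (batch dimension excluded)."""
    total = reduce(operator.mul, input_shape, 1)
    known = 1
    unknown_index = None
    for i, dim in enumerate(target_shape):
        if dim == -1:
            if unknown_index is not None:
                raise ValueError("Can only specify one unknown dimension.")
            unknown_index = i
        else:
            known *= dim
    output = list(target_shape)
    if unknown_index is not None:
        if known == 0 or total % known != 0:
            raise ValueError("Total size of new array must be unchanged.")
        output[unknown_index] = total // known
    elif known != total:
        raise ValueError("Total size of new array must be unchanged.")
    return tuple(output)


def output_shape_for(batch_input_shape, target_shape):
    # Keras-style shapes put the batch dimension first (often None).
    return (batch_input_shape[0],) + resolve_target_shape(
        batch_input_shape[1:], target_shape)


full = output_shape_for((None, 2, 4), (-1, 1))
print(full)              # (None, 8, 1)
# Prepending -1 to the full shape keeps the None batch entry:
print((-1,) + full)      # (-1, None, 8, 1)
# Dropping the batch entry first, as the proposed [1:] does:
print((-1,) + full[1:])  # (-1, 8, 1)
```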

jrao1 commented 7 years ago

This is fixed in core now.