naturomics / CapsNet-Tensorflow

A TensorFlow implementation of CapsNet (Capsules Net) from the paper "Dynamic Routing Between Capsules"
Apache License 2.0
3.8k stars, 1.17k forks

Set batch = 1 error #22

Closed: MilesZhao closed this issue 6 years ago

MilesZhao commented 6 years ago

When I set the batch size to 1, building the graph fails with a ValueError. I think it is caused by this line of code.

MilesZhao commented 6 years ago

Here is the full error output:

Traceback (most recent call last):
  File "/home/is-lab/.local/lib/python3.5/site-packages/tensorflow/python/framework/common_shapes.py", line 671, in _call_cpp_shape_fn_impl
    input_tensors_as_shapes, status)
  File "/usr/lib/python3.5/contextlib.py", line 66, in __exit__
    next(self.gen)
  File "/home/is-lab/.local/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Shape must be rank 2 but is rank 3 for 'Masking/MatMul' (op: 'MatMul') with input shapes: [10,16], [1,10,1].

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "tf_orginal_CapsNet.py", line 345, in <module>
    capsNet = CapsNet(is_training=is_training)
  File "tf_orginal_CapsNet.py", line 238, in __init__
    self.build_arch()
  File "tf_orginal_CapsNet.py", line 287, in build_arch
    self.masked_v = tf.matmul(tf.squeeze(self.caps2), tf.reshape(self.Y, (batch_size, 10, 1)), transpose_a=True)
  File "/home/is-lab/.local/lib/python3.5/site-packages/tensorflow/python/ops/math_ops.py", line 1816, in matmul
    a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)
  File "/home/is-lab/.local/lib/python3.5/site-packages/tensorflow/python/ops/gen_math_ops.py", line 1217, in _mat_mul
    transpose_b=transpose_b, name=name)
  File "/home/is-lab/.local/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "/home/is-lab/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2508, in create_op
    set_shapes_for_outputs(ret)
  File "/home/is-lab/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1873, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "/home/is-lab/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1823, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)
  File "/home/is-lab/.local/lib/python3.5/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn
    debug_python_shape_fn, require_shape_fn)
  File "/home/is-lab/.local/lib/python3.5/site-packages/tensorflow/python/framework/common_shapes.py", line 676, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Shape must be rank 2 but is rank 3 for 'Masking/MatMul' (op: 'MatMul') with input shapes: [10,16], [1,10,1].

naturomics commented 6 years ago

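This is caused by the tf.squeeze call in the line below. Without an axis argument, tf.squeeze removes every dimension of size 1, so when the batch size is 1 it also drops the batch dimension of caps2. tf.matmul then receives a rank-2 tensor ([10, 16]) and a rank-3 tensor ([1, 10, 1]):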

self.masked_v = tf.matmul(tf.squeeze(self.caps2), tf.reshape(self.Y, (batch_size, 10, 1)), transpose_a=True)
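The shape change can be reproduced outside the graph with NumPy, whose np.squeeze follows the same rule as tf.squeeze when no axis is given. This is a minimal sketch, not code from the repository: it assumes caps2 has shape [batch_size, 10, 16, 1] (consistent with the [10, 16] shape in the traceback), and squeezed_shape is a hypothetical helper.

```python
import numpy as np


def squeezed_shape(batch_size, axis=None):
    """Shape of a dummy caps2 array after np.squeeze (same rule as tf.squeeze)."""
    # Assumed caps2 layout: [batch_size, 10 digit capsules, 16 dims, 1].
    caps2 = np.zeros((batch_size, 10, 16, 1), dtype=np.float32)
    return np.squeeze(caps2, axis=axis).shape


print(squeezed_shape(8))           # (8, 10, 16): the batch axis is kept
print(squeezed_shape(1))           # (10, 16): the batch axis is dropped too
print(squeezed_shape(1, axis=-1))  # (1, 10, 16): only the trailing axis is removed
```

Only the batch-of-1 case loses the leading axis, which is why the graph builds for every other batch size.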

But who would set the batch size to 1? If you really need it, squeeze caps2 explicitly along its last dimension only, like this:

self.masked_v = tf.multiply(tf.squeeze(self.caps2, axis=-1), tf.reshape(self.Y, (-1, 10, 1)))
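The suggested fix can be sketched in NumPy the same way, since elementwise multiplication there broadcasts the same way tf.multiply does for these shapes. This is an illustration under assumptions, not the repository's code: caps2 is taken to be [batch_size, 10, 16, 1], Y to be one-hot labels of shape [batch_size, 10], and mask_capsules is a hypothetical helper.

```python
import numpy as np


def mask_capsules(caps2, Y):
    """Zero out every digit capsule except the one selected by the one-hot label Y."""
    v = np.squeeze(caps2, axis=-1)  # [batch, 10, 16]; the batch axis survives even when batch == 1
    y = np.reshape(Y, (-1, 10, 1))  # [batch, 10, 1]; -1 lets the batch size be inferred
    return v * y                    # broadcast over the 16 capsule dims -> [batch, 10, 16]


# Deterministic dummy capsule outputs for a batch of one image (assumed shape).
caps2 = (np.arange(10 * 16, dtype=np.float32) + 1.0).reshape(1, 10, 16, 1)
Y = np.zeros((1, 10), dtype=np.float32)
Y[0, 3] = 1.0  # hypothetical label: the image is a "3"

masked = mask_capsules(caps2, Y)
print(masked.shape)                                      # (1, 10, 16)
print(np.array_equal(masked[0, 3], caps2[0, 3, :, 0]))   # True
```

Note that this changes the output shape: the original tf.matmul with transpose_a=True selects the target capsule and yields [batch, 16, 1], whereas the multiply keeps all 10 capsule slots with the non-target ones zeroed, giving [batch, 10, 16]. Whatever consumes masked_v downstream (presumably the reconstruction decoder) has to accept that shape.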