tensorflow / models

Models and examples built with TensorFlow

[BUG] When changing the batch_size of the official ResNet model, raised dimension mismatch error. #5291

Closed JustinhoCHN closed 6 years ago

JustinhoCHN commented 6 years ago

Describe the problem

In the official ResNet code (CIFAR-10), the default batch_size is 64. Because the CIFAR-10 images are small, a bigger batch_size should be usable as well, but when I changed the batch_size to 16, it raised a dimension mismatch error:

InvalidArgumentError (see above for traceback): Incompatible shapes: [64] vs. [16] [[Node: Equal = Equal[T=DT_INT64, _device="/job:localhost/replica:0/task:0/device:GPU:0"](ArgMax, ArgMax_1)]] [[Node: cross_entropy/_1585 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_4572_cross_entropy", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

After debugging, I found that the 'final_size' of the resnet was hard-coded, in resnet_model.py line 472:

inputs = tf.reshape(inputs, [-1, self.final_size])

Here -1 stands for the batch dimension, which the reshape computes as 4096 // self.final_size. That value must match the default batch_size (64), so self.final_size was hard-coded to 64 (4096 // batch_size = 64), as you can see in cifar10_main.py line 187:

final_size=64,

But if you want to change the batch_size, you also have to change the final_size of the ResNet; otherwise it raises a dimension mismatch error, and there is no note reminding us to do that!
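To make the arithmetic concrete, here is a minimal pure-Python sketch of how the -1 in a reshape to [-1, final_size] is resolved. The helper name reshape_leading_dim and the 2x2 spatial map are assumptions chosen only so that the numbers match the 4096 elements and the [64] vs. [16] shapes reported here; they are not taken from resnet_model.py.

```python
from math import prod


def reshape_leading_dim(shape, final_size):
    """Return the size that -1 resolves to in reshape(x, [-1, final_size]).

    Hypothetical helper for illustration only: the leading dimension is
    whatever makes the total element count work out.
    """
    total = prod(shape)
    if total % final_size != 0:
        raise ValueError(
            "cannot reshape %d elements into [-1, %d]" % (total, final_size))
    return total // final_size


# Assumed activation shape: batch of 16, 2x2 spatial map, 64 channels.
# 16 * 2 * 2 * 64 = 4096 elements in total.
activations = (16, 2, 2, 64)
rows = reshape_leading_dim(activations, final_size=64)
print(rows)        # 64: the leading dimension is no longer the batch size
print(rows == 16)  # False: 16 labels are compared with 64 predictions

# When the spatial map is 1x1, the leading dimension stays at the batch size.
print(reshape_leading_dim((16, 1, 1, 64), final_size=64))  # 16
```

Under these assumptions the reshape succeeds silently and the mismatch only surfaces later, when the predictions are compared against the labels in the Equal op.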

Advice

Don't hard-code the final_size; we can calculate it before passing it to the model constructor:

# Assumed import path, matching the official/resnet layout seen in the traceback.
from official.resnet import resnet_model

_BATCH_SIZE = 32  # define the batch_size first
# _NUM_CLASSES is defined elsewhere in cifar10_main.py.


class Cifar10Model(resnet_model.Model):
    """Model class with appropriate defaults for CIFAR-10 data."""

    def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
                 version=resnet_model.DEFAULT_VERSION):
        """These are the parameters that work for CIFAR-10 data."""
        if resnet_size % 6 != 2:
            raise ValueError('resnet_size must be 6n + 2:', resnet_size)

        num_blocks = (resnet_size - 2) // 6

        super(Cifar10Model, self).__init__(
            resnet_size=resnet_size,
            bottleneck=False,
            num_classes=num_classes,
            num_filters=16,
            kernel_size=3,
            conv_stride=1,
            first_pool_size=None,
            first_pool_stride=None,
            second_pool_size=8,
            second_pool_stride=1,
            block_sizes=[num_blocks] * 3,
            block_strides=[1, 2, 2],
            # Calculate final_size from _BATCH_SIZE with integer division.
            final_size=4096 // _BATCH_SIZE,
            version=version,
            data_format=data_format)
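A minimal sketch of the calculation suggested above, with one extra guard. Integer division silently rounds down when the batch size does not divide 4096 (48, for example), so the helper below raises instead. The name final_size_for and the total_elements parameter are hypothetical and not part of the official models code, and 4096 is simply the figure quoted in this issue, assumed to stay fixed.

```python
# Figure quoted in this issue; treated as a fixed constant for the sketch.
_TOTAL_ELEMENTS = 4096


def final_size_for(batch_size, total_elements=_TOTAL_ELEMENTS):
    """Return total_elements // batch_size, refusing inexact divisions.

    Hypothetical helper: it only reproduces the arithmetic proposed in
    this issue and adds input validation.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got %d" % batch_size)
    if total_elements % batch_size != 0:
        raise ValueError(
            "batch_size %d does not divide %d evenly"
            % (batch_size, total_elements))
    return total_elements // batch_size


print(final_size_for(64))  # 64, the hard-coded default
print(final_size_for(32))  # 128
print(final_size_for(16))  # 256
```

Calling final_size_for(48) raises a ValueError rather than returning a truncated value, which makes an unsupported batch size visible at construction time instead of at the first training step.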

Source code / logs

Traceback (most recent call last):
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1322, in _do_call
    return fn(*args)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1307, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1409, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [64] vs. [16]
  [[Node: Equal = Equal[T=DT_INT64, _device="/job:localhost/replica:0/task:0/device:GPU:0"](ArgMax, ArgMax_1)]]
  [[Node: cross_entropy/_1585 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_4572_cross_entropy", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "main.py", line 327, in <module>
    main(argv=sys.argv)
  File "main.py", line 322, in main
    shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
  File "/home/jto/projects/dogs_cats_tf/official/resnet/resnet_run_loop.py", line 396, in resnet_main
    max_steps=flags.max_train_steps)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 363, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 843, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 859, in _train_model_default
    saving_listeners)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1059, in _train_with_estimatorspec
    _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 567, in run
    run_metadata=run_metadata)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 1043, in run
    run_metadata=run_metadata)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 1134, in run
    raise six.reraise(*original_exc_info)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/six.py", line 686, in reraise
    raise value
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 1119, in run
    return self._sess.run(*args, **kwargs)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 1191, in run
    run_metadata=run_metadata)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 971, in run
    return self._sess.run(*args, **kwargs)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 900, in run
    run_metadata_ptr)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1135, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1316, in _do_run
    run_metadata)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1335, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [64] vs. [16]
  [[Node: Equal = Equal[T=DT_INT64, _device="/job:localhost/replica:0/task:0/device:GPU:0"](ArgMax, ArgMax_1)]]
  [[Node: cross_entropy/_1585 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_4572_cross_entropy", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Caused by op 'Equal', defined at:
  File "main.py", line 327, in <module>
    main(argv=sys.argv)
  File "main.py", line 322, in main
    shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
  File "/home/jto/projects/dogs_cats_tf/official/resnet/resnet_run_loop.py", line 396, in resnet_main
    max_steps=flags.max_train_steps)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 363, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 843, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 856, in _train_model_default
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 831, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "main.py", line 303, in dogscats_model_fn
    multi_gpu=params['multi_gpu'])
  File "/home/jto/projects/dogs_cats_tf/official/resnet/resnet_run_loop.py", line 281, in resnet_model_fn
    tf.argmax(labels, axis=1), predictions['classes'])
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/ops/metrics_impl.py", line 407, in accuracy
    is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/ops/gen_math_ops.py", line 2529, in equal
    "Equal", x=x, y=y, name=name)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
    op_def=op_def)
  File "/home/jto/anaconda3/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1718, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Incompatible shapes: [64] vs. [16] [[Node: Equal = Equal[T=DT_INT64, _device="/job:localhost/replica:0/task:0/device:GPU:0"](ArgMax, ArgMax_1)]] [[Node: cross_entropy/_1585 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_4572_cross_entropy", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

karmel commented 6 years ago

Thanks for the report @JustinhoCHN -- would you like to submit a PR to fix?

JustinhoCHN commented 6 years ago

@karmel After double-checking, I found that this problem doesn't occur in the newest release. As I remember, I found this bug in version 1.8.0. Do we still need to submit a PR for that version?

karmel commented 6 years ago

No, if it is fixed, we can leave as-is. Thanks-- closing this.