charlesq34 / pointnet

PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation

A little problem about batch norm #195

Closed: WMF1997 closed this issue 4 years ago

WMF1997 commented 4 years ago

Hello @charlesq34,

First, thank you for your code!

I want to run the code, and I ran into a problem:

Problem description

I wanted to test the code in tf_util.py, so I picked out the conv2d function from that file and copied and modified a call to it from transform_nets.py.

That is where I hit the problem. I think it comes from the batch norm part:

ValueError: Shape must be rank 0 but is rank 1 for 'tconv1/bn/cond/Switch' (op: 'Switch') with input shapes: [1], [1].

The full error output:


WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

WARNING:tensorflow:From /home/wmf997/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1658   try:
-> 1659     c_op = c_api.TF_FinishOperation(op_desc)
   1660   except errors.InvalidArgumentError as e:

InvalidArgumentError: Shape must be rank 0 but is rank 1 for 'tconv1/bn/cond/Switch' (op: 'Switch') with input shapes: [1], [1].

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
~/下载/pointnet-master/utils/tf_util.py in <module>
    590             padding='VALID', stride=[1,1],
    591             is_training=tf.constant([True]), # if bn=True, some error may happen.
--> 592             scope='tconv1', bn=True, bn_decay=None)
    593   s = tf.Session()
    594   s.run(tf.global_variables_initializer())

~/下载/pointnet-master/utils/tf_util.py in conv2d(inputs, num_output_channels, kernel_size, scope, stride, padding, use_xavier, stddev, weight_decay, activation_fn, bn, bn_decay, is_training)
    169         # bn_decay = tf.reshape(bn_decay, [])
    170         outputs = batch_norm_for_conv2d(outputs, is_training,
--> 171                                         bn_decay=bn_decay, scope='bn')
    172 
    173       if activation_fn is not None:

~/下载/pointnet-master/utils/tf_util.py in batch_norm_for_conv2d(inputs, is_training, bn_decay, scope)
    540       normed:      batch-normalized maps
    541   """
--> 542   return batch_norm_template(inputs, is_training, scope, [0,1,2], bn_decay)
    543 
    544 

~/下载/pointnet-master/utils/tf_util.py in batch_norm_template(inputs, is_training, scope, moments_dims, bn_decay)
    484     ema_apply_op = tf.cond(is_training,
    485                            lambda: ema.apply([batch_mean, batch_var]),
--> 486                            lambda: tf.no_op())
    487 
    488     # Update moving average and return current batch's avg and var.

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py in cond(pred, true_fn, false_fn, strict, name, fn1, fn2)
   2083     if isinstance(pred, bool):
   2084       raise TypeError("pred must not be a Python bool")
-> 2085     p_2, p_1 = switch(pred, pred)
   2086     pivot_1 = array_ops.identity(p_1, name="switch_t")
   2087     pivot_2 = array_ops.identity(p_2, name="switch_f")

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py in switch(data, pred, dtype, name)
    360     pred = ops.convert_to_tensor(pred, name="pred")
    361     if isinstance(data, ops.Tensor):
--> 362       return gen_control_flow_ops.switch(data, pred, name=name)
    363     else:
    364       if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/gen_control_flow_ops.py in switch(data, pred, name)
    861   # Add nodes to the TensorFlow graph.
    862   _, _, _op = _op_def_lib._apply_op_helper(
--> 863         "Switch", data=data, pred=pred, name=name)
    864   _result = _op.outputs[:]
    865   _inputs_flat = _op.inputs

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    786         op = g.create_op(op_type_name, inputs, output_types, name=scope,
    787                          input_types=input_types, attrs=attr_protos,
--> 788                          op_def=op_def)
    789       return output_structure, op_def.is_stateful, op
    790 

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in create_op(***failed resolving arguments***)
   3298           input_types=input_types,
   3299           original_op=self._default_original_op,
-> 3300           op_def=op_def)
   3301       self._create_op_helper(ret, compute_device=compute_device)
   3302     return ret

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in __init__(self, node_def, g, inputs, output_types, control_inputs, input_types, original_op, op_def)
   1821           op_def, inputs, node_def.attr)
   1822       self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
-> 1823                                 control_input_ops)
   1824 
   1825     # Initialize self._outputs.

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1660   except errors.InvalidArgumentError as e:
   1661     # Convert to ValueError for backwards compatibility.
-> 1662     raise ValueError(str(e))
   1663 
   1664   return c_op

ValueError: Shape must be rank 0 but is rank 1 for 'tconv1/bn/cond/Switch' (op: 'Switch') with input shapes: [1], [1].
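The traceback shows where the rank check comes from: `batch_norm_template` passes `is_training` straight to `tf.cond`, and `tf.cond` builds a `Switch` op with `switch(pred, pred)`, so the predicate is both of the op's inputs. That is why both input shapes are reported as `[1]`: `tf.constant([True])` is a one-element vector (rank 1), while `Switch` needs a scalar boolean (rank 0, shape `[]`). Below is a toy pure-Python model of that check, not TensorFlow's code; the helper names (`check_switch_pred`, `cond_pred_check`, `reshape_to_scalar`) are hypothetical. In the repro, the corresponding change would presumably be passing `is_training=tf.constant(True)`, a scalar placeholder such as `tf.placeholder(tf.bool, shape=())`, or `tf.reshape(is_training, [])`.

```python
# Toy model (not TensorFlow code) of the predicate check behind tf.cond.
# Shapes are tuples: () is a scalar (rank 0), (1,) is a 1-element vector (rank 1).


def rank(shape):
    """Number of dimensions: () -> 0, (1,) -> 1, (4, 10, 3) -> 3."""
    return len(shape)


def check_switch_pred(name, data_shape, pred_shape):
    """Raise an error shaped like the traceback's if pred is not a scalar."""
    if rank(pred_shape) != 0:
        raise ValueError(
            "Shape must be rank 0 but is rank %d for '%s' (op: 'Switch') "
            "with input shapes: %s, %s."
            % (rank(pred_shape), name, list(data_shape), list(pred_shape)))


def cond_pred_check(name, pred_shape):
    # tf.cond calls switch(pred, pred), so the predicate is passed as both
    # the data and the pred input; hence both shapes read [1] in the error.
    check_switch_pred(name, pred_shape, pred_shape)


def reshape_to_scalar(shape):
    """Model of reshaping to []: legal only if there is exactly one element."""
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    if num_elements != 1:
        raise ValueError("cannot reshape %d elements to a scalar" % num_elements)
    return ()


for label, shape in [("tf.constant([True])", (1,)),
                     ("tf.constant(True)", ()),
                     ("tf.reshape(..., [])", reshape_to_scalar((1,)))]:
    try:
        cond_pred_check("tconv1/bn/cond/Switch", shape)
        print(label, "-> ok")
    except ValueError as err:
        print(label, "->", err)
```

Only the rank-1 predicate fails; a scalar, or a one-element tensor reshaped to `[]`, passes the check.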

My environment:

OS: Ubuntu 16.04 (no GPU)
TensorFlow: 1.13.1 (CPU)

Is my TensorFlow version too new to run this code? I also want to run it on TensorFlow 1.8.0 on another machine with CUDA, and I don't know whether it will work there. (The CPU machine is only used for reading your source code.)

My added code:

# Still in tf_util.py

if __name__ == '__main__':
  x = tf.random_normal([4, 10, 3])
  x1 = tf.expand_dims(x, -1)
  y = conv2d(x1, 64, [1,3],
            padding='VALID', stride=[1,1],
            is_training=tf.constant([True]), # bn=True raises the ValueError above; bn=False works.
            scope='tconv1', bn=True, bn_decay=None)
  s = tf.Session()
  s.run(tf.global_variables_initializer())
  yy = s.run(y) # with bn=False, yy.shape is [4, 10, 1, 64] as expected.
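The expected `yy.shape` of `[4, 10, 1, 64]` can be checked by hand. `tf.expand_dims(x, -1)` turns `[4, 10, 3]` into `[4, 10, 3, 1]`, and with `'VALID'` padding each spatial size becomes `(in - kernel) // stride + 1`, so the height stays `(10 - 1) // 1 + 1 = 10` and the width collapses to `(3 - 3) // 1 + 1 = 1`. A minimal sketch of that arithmetic, assuming the NHWC layout (batch, height, width, channels); `conv2d_valid_output_shape` is a hypothetical helper, not part of tf_util.py:

```python
# Output shape of a 2-D convolution with 'VALID' padding on NHWC input.
# For each spatial dimension: out = (in - kernel) // stride + 1.


def conv2d_valid_output_shape(input_shape, kernel_size, num_output_channels, stride):
    batch, height, width, _in_channels = input_shape
    kernel_h, kernel_w = kernel_size
    stride_h, stride_w = stride
    out_h = (height - kernel_h) // stride_h + 1
    out_w = (width - kernel_w) // stride_w + 1
    # The channel count is set by the number of filters, not by the input.
    return (batch, out_h, out_w, num_output_channels)


# x: [4, 10, 3] -> tf.expand_dims(x, -1): [4, 10, 3, 1]
print(conv2d_valid_output_shape((4, 10, 3, 1), (1, 3), 64, (1, 1)))  # (4, 10, 1, 64)
```

A `[1, 3]` kernel spanning the full width of 3 is what reduces each point's xyz coordinates to a single column of 64 features.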

Yours sincerely, @wmf1997

WMF1997 commented 4 years ago

Perhaps it is the TensorFlow version? In fact, I am now using TF 1.12.0 or later (only the lab server uses TF 1.8.0, which is already fairly recent). So I want to change the code, i.e. all the operation definitions in tf_util.py (the definition of variables on the CPU still works fine, so I will not change that). I think it will work, and I will try it later today.

Yours sincerely, @wmf1997

WMF1997 commented 4 years ago

I tried my idea in TF 1.13.0 and it finally works. Closing the issue.