google / prettytensor

Pretty Tensor: Fluent Networks in TensorFlow

batch normalization in conv2d makes assertion error #24

Closed jhlee525 closed 8 years ago

jhlee525 commented 8 years ago

Whenever I insert batch normalization via Pretty Tensor's conv2d function, it fails an assertion.

My simple test code is

import tensorflow as tf
import prettytensor as pt

x = tf.placeholder(tf.float32, [None, 224, 224, 3])

net = pt.wrap(x)
net = net.conv2d(7, 64, batch_normalize=True)

Then the console prints the following traceback:

 File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python2.7/dist-packages/prettytensor/pretty_tensor_class.py", line 1980, in method
    result = func(non_seq_layer, *args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/prettytensor/pretty_tensor_image_methods.py", line 246, in __call__
    y = input_layer.with_tensor(y).batch_normalize()
  File "/usr/local/lib/python2.7/dist-packages/prettytensor/pretty_tensor_class.py", line 1980, in method
    result = func(non_seq_layer, *args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/prettytensor/pretty_tensor_image_methods.py", line 59, in __call__
    assert isinstance(learned_moments_update_rate, tf.compat.real_types)
AssertionError

Thank you

eiderman commented 8 years ago

Thanks for bringing this to my attention. The problem is that no appropriate defaults were set for a couple of batch_normalize parameters, and at the moment the only way to pass them in is through defaults_scope.

The workaround for now is to set them in a defaults_scope:

with pt.defaults_scope(learned_moments_update_rate=0.0003, variance_epsilon=0.001):
    x = tf.placeholder(tf.float32, [None, 224, 224, 3])
    net = pt.wrap(x)
    net = net.conv2d(7, 64, batch_normalize=True)

eiderman commented 8 years ago

Defaults have been added, and pt.BatchNormalizationArguments lets you customize the values passed through conv2d. Please reopen if this doesn't work for you.
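
For reference, a minimal sketch of how that might look, assuming the BatchNormalizationArguments fields mirror the keyword names used with defaults_scope above (the field names are my assumption, not confirmed in this thread):

import tensorflow as tf
import prettytensor as pt

x = tf.placeholder(tf.float32, [None, 224, 224, 3])

# Assumed field names; check pt.BatchNormalizationArguments for the exact signature.
bn_args = pt.BatchNormalizationArguments(
    learned_moments_update_rate=0.0003,
    variance_epsilon=0.001)

net = pt.wrap(x)
net = net.conv2d(7, 64, batch_normalize=bn_args)

This keeps the batch-normalization settings local to the conv2d call instead of applying them to everything inside a defaults_scope block.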