tensorflow / skflow

Simplified interface for TensorFlow (mimicking Scikit Learn) for Deep Learning
Apache License 2.0

batch normalization implementation is wrong #117

Closed raingo closed 8 years ago

raingo commented 8 years ago

The running mean and variance are tensors, not variables, so they can't be saved in snapshots. As a result, at test time the batch mean and variance are randomly initialized.

raingo commented 8 years ago

I just figured this out by myself. It's not wrong, but it lacks clear documentation.

I believe the current implementation is fine, but you have to take special care with how you construct the saver.

If the saver is constructed with saver = tf.train.Saver(tf.all_variables()), there is no problem at all.

If the saver is constructed with saver = tf.train.Saver(tf.trainable_variables()), the shadow variables in the moving average class are not saved or loaded. A hacky and buggy workaround:

# Save the trainable variables plus the moving-average shadow variables.
tvars = tf.trainable_variables()
ema_name = 'ExponentialMovingAverage:0'
ema_vars = [var for var in tf.all_variables()
            if var.name.endswith(ema_name)]
self._saver = tf.train.Saver(tvars + ema_vars)

Please correct me if I am wrong. I hope this will help to improve the docs of the batch_norm function.

ilblackdragon commented 8 years ago

Why would you want to save only trainable variables? Also, estimator.save should already take care of saving. Why do you need your own saver?

raingo commented 8 years ago

The reason to save only the trainable variables is to save disk space. For a CNN, the shadow variables in the optimizer triple the checkpoint size. Saving only the trainable variables is also done in many TensorFlow examples.
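As a rough illustration of where the factor of three comes from, here is a minimal sketch. It assumes an optimizer such as Adam, which keeps two extra values (first and second moment estimates) for every trainable value, and it ignores non-trainable variables. The helper name `checkpoint_values` and the model size are made up for the example.

```python
def checkpoint_values(num_trainable, slots_per_var):
    """Count stored values when each trainable value has
    `slots_per_var` optimizer slots saved alongside it."""
    return num_trainable * (1 + slots_per_var)


n = 1000000  # hypothetical number of trainable values in the model

# Assumption: Adam-style optimizer with two slots per trainable value.
with_optimizer = checkpoint_values(n, 2)

# Ratio of "save everything" to "save only trainable variables".
ratio = with_optimizer / float(n)
```

With a single-slot optimizer such as Adagrad, the same calculation gives a factor of two rather than three.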

ilblackdragon commented 8 years ago

Disk space is dirt cheap these days... But then you should just filter out the Adam/* or Adagrad/* variables from all_variables when saving, instead of trying to collect all the others, because your code above misses a lot of other things (like global_step, any random projections or random embeddings used in some models, etc.).

I'm closing this issue then. If you are interested, feel free to add a PR with an optional flag on the save method to filter out the optimizer's variables.
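The filtering approach could look roughly like the sketch below. It is not skflow code: `is_optimizer_var`, `vars_to_save`, and `FakeVar` are hypothetical names, and it assumes optimizer variables can be recognized by a path component such as `Adam`, `Adam_1`, or `Adagrad` in the variable name, which should be checked against the names in your own graph. `FakeVar` stands in for a TensorFlow variable so the example runs without TensorFlow; only the `.name` attribute is used.

```python
import re
from collections import namedtuple

# Stand-in for a TensorFlow variable; only `.name` is read below.
FakeVar = namedtuple("FakeVar", ["name"])

# Assumed optimizer name tokens; extend this for other optimizers.
OPTIMIZER_TOKENS = ("Adam", "Adagrad")


def is_optimizer_var(name, tokens=OPTIMIZER_TOKENS):
    """Return True if any '/'-separated component of `name` is an
    optimizer token, optionally with a numeric suffix like `Adam_1`."""
    path = name.split(":")[0]  # drop the ':0' output index
    pattern = re.compile(
        r"^(%s)(_\d+)?$" % "|".join(re.escape(t) for t in tokens))
    return any(pattern.match(part) for part in path.split("/"))


def vars_to_save(all_vars, tokens=OPTIMIZER_TOKENS):
    """Keep every variable except those that look like optimizer state."""
    return [v for v in all_vars if not is_optimizer_var(v.name, tokens)]


example_vars = [
    FakeVar("conv1/weights:0"),
    FakeVar("conv1/weights/Adam:0"),
    FakeVar("conv1/weights/Adam_1:0"),
    FakeVar("global_step:0"),
    FakeVar("bn/mean/ExponentialMovingAverage:0"),
]
kept_names = [v.name for v in vars_to_save(example_vars)]
```

In real use, the filtered list would be passed to the saver, for example tf.train.Saver(vars_to_save(tf.all_variables())). Because this starts from all variables and removes only optimizer state, global_step and the moving-average shadow variables are kept without being listed explicitly.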