tensorflow / tensorflow

An Open Source Machine Learning Framework for Everyone
Apache License 2.0

tf.train.exponential_decay examples use 32 bit number for batch #2859

Closed: jwise closed this issue 8 years ago

jwise commented 8 years ago

As far as I can tell, the tf.train.exponential_decay examples use a 32-bit signed number for global_step, because the batch counter is a tf.int32. This means that long runs can produce unpleasant surprises in the learning rate, which makes for a frustrating experience for new users. [1] The examples should be updated to initialize batch with dtype=tf.int64.

I originally believed this to be a tf.train.exponential_decay bug, and wrote a minimal repro case for it. I now understand that the problem comes from the variable's dtype, but you can have the repro case below anyway, because I think it makes it a little more obvious what kind of thing can happen. When you run it, you'll note a discontinuity in the learning rate between epoch 85 and 86, which comes from batch * batch_size overflowing (and results in losing an evening's training...).

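import tensorflow as tf

def main(argv=None):
  epoch_size = 25000000
  epoches = 100
  batch_size = 16384

  batch = tf.Variable(0) # <-- int32 by default; this line, in examples, should be changed to have dtype=tf.int64, with explanation of why
  # The numeric arguments below are placeholders (assumed values, with one
  # decay per epoch); the overflow of batch * batch_size does not depend on them.
  learning_rate = tf.train.exponential_decay(
      0.01,                # initial learning rate (assumed)
      batch * batch_size,  # samples seen so far; this int32 product overflows
      epoch_size,          # decay_steps (assumed)
      0.95,                # decay_rate (assumed)
      staircase=True)
  loss = tf.Variable(0.0)
  eval_frequency = int(epoch_size / batch_size / 2)

  optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9) \
     .minimize(loss, global_step=batch)

  with tf.Session() as sess:
    # Variables must be initialized before the first run call
    # (later 1.x releases name this tf.global_variables_initializer()).
    sess.run(tf.initialize_all_variables())
    for step in range(int(epoches * epoch_size) // batch_size):
      _, b, lr = sess.run([optimizer, batch, learning_rate])
      if step % eval_frequency == 0:
        print("Step %d (epoch %.2f): batch %d, learning rate %.6f" %
            (step, float(step) * batch_size / epoch_size, b, lr))

if __name__ == '__main__':
  tf.app.run()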

[1] No, I did not checkpoint midway through. Yes, I have learned my lesson.
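For anyone who wants to check the arithmetic without TensorFlow, here is a rough pure-Python sketch of where batch * batch_size crosses the int32 limit, using the epoch_size and batch_size from the repro. The wrap_int32 helper is hypothetical and assumes that an overflowing int32 product wraps around in two's complement.

```python
def wrap_int32(value):
    """Reduce an integer to the two's-complement int32 range (assumed overflow behavior)."""
    return (value + 2**31) % 2**32 - 2**31

epoch_size = 25000000
batch_size = 16384

# A signed 32-bit integer holds at most 2**31 - 1, so the first batch
# whose sample count reaches 2**31 is the first one that overflows.
first_bad_batch = 2**31 // batch_size
overflow_epoch = first_bad_batch * batch_size / float(epoch_size)

print("first overflowing batch:", first_bad_batch)
print("epoch at overflow: %.2f" % overflow_epoch)
print("last good value:", wrap_int32((first_bad_batch - 1) * batch_size))
print("first wrapped value:", wrap_int32(first_bad_batch * batch_size))
```

The overflow lands a little before epoch 86, which matches the discontinuity described above: the sample count handed to the decay schedule goes from a large positive number to a large negative one in a single step.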

sherrym commented 8 years ago

Examples are supposed to cover common use cases. If you do expect your experiments to overflow a 32-bit integer, then yes, please explicitly set it to int64.
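One way to tell in advance whether a planned run will overflow a 32-bit integer is a quick back-of-the-envelope check like the sketch below. The function name and its counts_samples flag are made up for illustration and are not part of TensorFlow.

```python
INT32_MAX = 2**31 - 1

def step_counter_overflows(epoch_size, epochs, batch_size, counts_samples=True):
    """Return True if the largest value fed to the schedule exceeds int32.

    counts_samples=True models passing batch * batch_size as the global step
    (as in the repro above); False models passing the raw batch counter.
    """
    batches = (epoch_size * epochs) // batch_size
    peak = batches * batch_size if counts_samples else batches
    return peak > INT32_MAX

# Numbers from the repro case in this issue.
print(step_counter_overflows(25000000, 100, 16384))
print(step_counter_overflows(25000000, 100, 16384, counts_samples=False))
```

With the repro's numbers the raw batch counter fits comfortably in int32, but the batch * batch_size product does not, so that is the case where declaring the variable with dtype=tf.int64 matters.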

jwise commented 8 years ago

Hmm, this might be an indication that what I'm doing is unreasonable :-) But my understanding is that "more data is better", and that 12 hour training runs are not super uncommon in the world of DNNs. (My inexperience may show through here!) I had actually (obviously, incorrectly) assumed that the type of that Variable would have been some kind of floating point datatype, given that it wasn't specified in the documentation.

Given how much detail the rest of the (wonderful!) documentation goes into about gotchas and issues that one might have, I think I respectfully disagree. I tripped on this relatively readily (after only a week or so of using the tool), and it cost me 10 hours or so of training time and a day's worth of results. I think that highlighting this could help new users to think through what exactly is going on in the examples.

But, it's your project! Either way, thanks for the consideration.