tensorflow / tpu

Reference models and tools for Cloud TPUs.
https://cloud.google.com/tpu/
Apache License 2.0

iterations_per_loop on inference #726

Open piercefreeman opened 4 years ago

piercefreeman commented 4 years ago

The documentation I've read on TPU training stresses the importance of keeping data readily accessible, so that we can reach full utilization of the chip.

Looking at the mnist_tpu.py training pipeline, I see that it queues 50 batches at a time onto the TPU via the iterations_per_loop setting (defined within TPUEstimator).

tf.flags.DEFINE_integer("iterations", 50,
                        "Number of iterations per TPU training loop.")

...

  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=True),
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations, num_shards=FLAGS.num_shards),
  )

...

estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)

INFO:tensorflow:global_step/sec: 255.509
I0321 22:04:31.084316 140264358516544 tpu_estimator.py:2307] global_step/sec: 255.509
INFO:tensorflow:examples/sec: 261641
I0321 22:04:31.084694 140264358516544 tpu_estimator.py:2308] examples/sec: 261641
INFO:tensorflow:Enqueue next (50) batch(es) of data to infeed.
I0321 22:04:31.085694 140264358516544 tpu_estimator.py:600] Enqueue next (50) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (50) batch(es) of data from outfeed.
I0321 22:04:31.085918 140264358516544 tpu_estimator.py:604] Dequeue next (50) batch(es) of data from outfeed.
INFO:tensorflow:global_step/sec: 257.976
I0321 22:04:31.278119 140264358516544 tpu_estimator.py:2307] global_step/sec: 257.976

During prediction, however, it seems to ignore the iterations_per_loop parameter and queues only a single batch onto the TPU at a time.

predictions = estimator.predict(input_fn=predict_input_fn)

INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
I0321 22:05:06.856015 140264358516544 tpu_estimator.py:600] Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
I0321 22:05:06.856195 140264358516544 tpu_estimator.py:604] Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.

This looks like a potential performance bottleneck when running inference over several thousand batches. Am I interpreting the code path correctly? And if so, is there any workaround to pre-load multiple batches at prediction time?

gagika commented 4 years ago

The predict() API processes one batch at a time. You can choose a larger batch size when possible. For several thousand batches, please use the evaluate() API instead.
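
For example, here is a minimal sketch of raising the prediction batch size via TPUEstimator's predict_batch_size argument. This is an illustration, not code from mnist_tpu.py: model_fn and run_config are assumed from the snippets above, images is a placeholder for your input tensors, and 1024 is an arbitrary example value.

import tensorflow as tf  # TF 1.x, as in the snippets above

def predict_input_fn(params):
  # TPUEstimator passes the per-core batch size to input_fn via params.
  batch_size = params["batch_size"]
  dataset = tf.data.Dataset.from_tensor_slices(images)  # `images` is a placeholder
  # drop_remainder=True yields the static shapes the TPU requires.
  return dataset.batch(batch_size, drop_remainder=True)

estimator = tf.estimator.tpu.TPUEstimator(
    model_fn=model_fn,                  # assumed from the example above
    config=run_config,
    train_batch_size=FLAGS.batch_size,
    predict_batch_size=1024,            # illustrative: amortizes per-batch overhead
)
predictions = estimator.predict(input_fn=predict_input_fn)

Note that predict() still enqueues one batch per step, so a larger predict_batch_size moves more examples per enqueue/dequeue round trip rather than pre-loading multiple batches.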