
My distributed training code gets a slower speed (sec/batch) than cifar10_train.py #4589

Closed: Told closed this issue 6 years ago

Told commented 6 years ago

Please go to Stack Overflow for help and support:

http://stackoverflow.com/questions/tagged/tensorflow

Also, please understand that many of the models included in this repository are experimental and research-style code. If you open a GitHub issue, here is our policy:

  1. It must be a bug, a feature request, or a significant problem with documentation (for small docs fixes please send a PR instead).
  2. The form below must be filled out.

Here's why we have that policy: TensorFlow developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow.


System information

You can collect some of this information using our environment capture script:

https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh

You can obtain the TensorFlow version with

python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"

Describe the problem

Based on cifar10_train.py and the TensorFlow distributed training documentation, I wrote a distributed version, cifar10_train_distributed.py, and ran it on a Kubernetes cluster with 2 workers and 1 ps, but it runs slower. cifar10_train.py on 1 GPU (Tesla P100) reaches 0.008 sec/batch, while cifar10_train_distributed.py only reaches 0.027 sec/batch. My code is below. I do not understand why. Help!
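
Roughly, the --tf_config value each pod receives looks like the sketch below. The host:port entries are placeholders rather than the real k8s service addresses, and the exact role layout is approximate; one node runs with type "master" so it becomes the chief in the code below.

import json

# Placeholder cluster layout: 1 ps pod plus 2 worker pods, one of which runs
# as the "master" (chief). Host names and ports are illustrative only.
tf_config = {
    "cluster": {
        "master": ["cifar-master-0:2222"],
        "worker": ["cifar-worker-0:2222"],
        "ps": ["cifar-ps-0:2222"],
    },
    "task": {"type": "master", "index": 0},  # differs per pod
}
print(json.dumps(tf_config))  # this string is what gets passed as --tf_config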

Source code / logs

import json
import time
from datetime import datetime

import tensorflow as tf

import cifar10

FLAGS = tf.app.flags.FLAGS


def train():
    """Train CIFAR-10 for a number of steps on the cluster given by --tf_config."""
    tf_config_json = FLAGS.tf_config
    tf_config = json.loads(tf_config_json)

    # Get cluster info and build the spec object used to init each node.
    cluster_spec = tf_config.get("cluster", {})
    cluster_spec_object = tf.train.ClusterSpec(cluster_spec)

    # Get the current task.
    task = tf_config.get("task", {})
    job_name = task['type']
    job_index = task['index']

    # TF server definition.
    server_def = tf.train.ServerDef(
        cluster=cluster_spec_object.as_cluster_def(),
        protocol="grpc",
        job_name=job_name,
        task_index=job_index)

    print(cluster_spec, task)
    server = tf.train.Server(server_def)
    is_chief = (job_name == "master")

    # Parameter servers only serve variables and never return from join().
    if job_name == 'ps':
        print("ps join..")
        server.join()

    worker_device = "/job:%s/task:%d" % (job_name, job_index)
    print(worker_device)

    with tf.device(tf.train.replica_device_setter(
            worker_device=worker_device, cluster=cluster_spec_object)):
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10. Force the input pipeline to
        # CPU:0 to avoid operations sometimes ending up on GPU and resulting
        # in a slow down.
        with tf.device("/cpu:0"):
            images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

    with tf.train.MonitoredTrainingSession(
            master=server.target,
            is_chief=is_chief,
            checkpoint_dir=FLAGS.train_dir,
            hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                   tf.train.NanTensorHook(loss)],
            config=tf.ConfigProto(
                log_device_placement=FLAGS.log_device_placement)) as mon_sess:
        start_train_time = time.time()
        count_step = 0
        while not mon_sess.should_stop():
            step_start_time = time.time()
            _, loss_value = mon_sess.run([train_op, loss])
            step_end_time = time.time()
            if count_step % FLAGS.log_frequency == 0:
                duration = step_end_time - step_start_time
                examples_per_sec = FLAGS.batch_size / duration
                format_str = ('%s: step %d %s, loss = %.2f '
                              '(%.1f examples/sec; %.3f sec/batch)')
                print(format_str % (datetime.now(), count_step, job_name,
                                    loss_value, examples_per_sec, duration))
            count_step += 1

    end_train_time = time.time()
    cost_train_time = end_train_time - start_train_time
    print("training cost : %d" % cost_train_time)
bitfort commented 6 years ago

I recommend you touch base with http://stackoverflow.com/questions/tagged/tensorflow ; it is a great starting point for debugging this issue. Feel free to post here if you discover a specific problem.