My distributed training code is slower (sec/batch) than cifar10_train.py #4589

Closed Told closed 6 years ago

Told commented 6 years ago

Describe the problem

Based on cifar10_train.py and the TensorFlow distributed-training docs, I wrote a distributed version, cifar10_train_distributed.py, and ran it on Kubernetes with 2 workers and 1 parameter server. However, it is slower: cifar10_train.py on 1 GPU (Tesla P100) takes 0.008 sec/batch, while cifar10_train_distributed.py takes 0.027 sec/batch. My code is below. I do not understand why it is slower — please help!

Source code / logs

def train():
    """Train CIFAR-10 with between-graph replication on a PS/worker cluster.

    The cluster layout and this process's role are read from the JSON string
    in ``FLAGS.tf_config``: a ``"cluster"`` entry (passed to
    ``tf.train.ClusterSpec``) and a ``"task"`` entry with ``"type"`` (the job
    name) and ``"index"`` (the task index).

    A ``"ps"`` task starts its ``tf.train.Server`` and blocks in
    ``server.join()``; it never builds the training graph. Every other task
    builds the CIFAR-10 graph under ``tf.train.replica_device_setter`` (so
    variables live on the parameter servers) and runs the training loop
    through its own server, printing per-step timing every
    ``FLAGS.log_frequency`` steps and the total wall time at the end.

    NOTE(review): because variables live on the parameter server, each step
    presumably exchanges parameters/gradients over the network, so a higher
    sec/batch than single-GPU ``cifar10_train.py`` is expected for a model
    this small.
    """
    tf_config = json.loads(FLAGS.tf_config)

    # Get cluster info and build the spec object used to init each node.
    cluster_spec = tf_config.get("cluster", {})
    cluster_spec_object = tf.train.ClusterSpec(cluster_spec)
    # Get the current task.
    task = tf_config.get("task", {})
    job_name = task["type"]
    # int() so the "%d" device string below also works if the index arrives
    # as a JSON string.
    job_index = int(task["index"])
    print(cluster_spec, task)

    # tf.train.Server builds the ServerDef from the cluster spec and this
    # task's job name/index (the hand-written ServerDef was incomplete).
    server = tf.train.Server(cluster_spec_object,
                             job_name=job_name,
                             task_index=job_index)

    if job_name == "ps":
        print("ps join..\n")
        # A parameter server only serves variables to the workers; it must
        # not build or run the training graph itself.
        server.join()
        return

    # Exactly one task must be chief (MonitoredTrainingSession's chief
    # initializes the variables). Use the "master" task when the cluster has
    # one; otherwise fall back to worker 0 so a worker/ps-only cluster still
    # has a chief.
    if "master" in cluster_spec:
        is_chief = job_name == "master"
    else:
        is_chief = job_name == "worker" and job_index == 0

    worker_device = "/job:%s/task:%d" % (job_name, job_index)
    with tf.device(tf.train.replica_device_setter(
            worker_device=worker_device, cluster=cluster_spec_object)):
        global_step = tf.train.get_or_create_global_step()
        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending
        # up on GPU and resulting in a slow down.
        with tf.device("/cpu:0"):
            images, labels = cifar10.distorted_inputs()
        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)
        # Calculate loss.
        loss = cifar10.loss(logits, labels)
        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

    # Without a stop hook should_stop() never becomes true, so the loop
    # below would never end and the timing summary would be unreachable.
    # NOTE(review): assumes FLAGS.max_steps is defined as in cifar10_train.py.
    hooks = [tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
             tf.train.NanTensorHook(loss)]
    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    format_str = ('%s: step %d %s, loss = %.2f (%.1f examples/sec; %.3f '
                  'sec/batch)')

    # master=server.target connects the session to this task's server (and
    # thus the cluster); is_chief decides which task initializes variables.
    with tf.train.MonitoredTrainingSession(
            master=server.target,
            is_chief=is_chief,
            hooks=hooks,
            config=config) as mon_sess:
        start_train_time = time.time()
        count_step = 0
        while not mon_sess.should_stop():
            step_start_time = time.time()
            _, loss_value = mon_sess.run([train_op, loss])
            step_end_time = time.time()
            if count_step % FLAGS.log_frequency == 0:
                duration = step_end_time - step_start_time
                examples_per_sec = FLAGS.batch_size / duration
                print(format_str % (datetime.now(), count_step, job_name,
                                    loss_value, examples_per_sec, duration))
            count_step += 1

    end_train_time = time.time()
    cost_train_time = end_train_time - start_train_time
    print("training cost : %d" % cost_train_time)
