tensorflow / tensorflow

An Open Source Machine Learning Framework for Everyone
https://tensorflow.org
Apache License 2.0

Stuck at prepare_or_wait_for_session for workers when running on kubernetes #3037

Closed. perhapszzy closed this issue 8 years ago.

perhapszzy commented 8 years ago

Environment info

Operating System: Kubernetes/Ubuntu 14.04

Steps to reproduce

  1. Start the TensorFlow servers on Kubernetes using the YAML generated by the Python code (a rough sketch of such a server script is shown after these steps).
  2. Start the TensorFlow training code:
python mnist_dnn.py --worker_grpc_url=grpc://180.101.191.78:30001 --worker_index=0 --workers=180.101.191.78:30001,180.101.191.78:30002,180.101.191.78:30003 --parameter_servers=tf-ps0:2222,tf-ps1:2222
python mnist_dnn.py --worker_grpc_url=grpc://180.101.191.78:30002 --worker_index=1 --workers=180.101.191.78:30001,180.101.191.78:30002,180.101.191.78:30003 --parameter_servers=tf-ps0:2222,tf-ps1:2222
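
For reference, the per-pod server process from step 1 would look roughly like the sketch below. This is not the actual script behind the generated YAML; the flag names job_name and task_index are illustrative, and only tf.train.ClusterSpec and tf.train.Server are standard API from that release.

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string("job_name", "worker", "'ps' or 'worker'")
flags.DEFINE_integer("task_index", 0, "Index of this task within its job")
flags.DEFINE_string("workers", None, "Comma-separated worker hosts")
flags.DEFINE_string("parameter_servers", None, "Comma-separated ps hosts")
FLAGS = flags.FLAGS

# Same cluster layout as passed on the command lines above.
cluster = tf.train.ClusterSpec({"ps": FLAGS.parameter_servers.split(","),
                                "worker": FLAGS.workers.split(",")})

# Each pod starts one in-process gRPC server for its own (job, task) slot
# and then blocks, serving graph and session requests.
server = tf.train.Server(cluster,
                         job_name=FLAGS.job_name,
                         task_index=FLAGS.task_index)
server.join()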

What have you tried?

  1. The worker with index 0 (chief) can execute normally.
  2. The same setup was able to execute well previously (using the same YAML and code).
  3. I tried to restart the servers, but it didn't work.
  4. All other workers are stuck at prepare_or_wait_for_session. However, the logs seem to suggest that those workers are actually executing.

The log is here and the code is here:

import sys
import time

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import datetime

flags = tf.app.flags
flags.DEFINE_string("data_dir", "/tmp/data",
                    "Directory for storing mnist data")

flags.DEFINE_boolean("download_only", False,
                     "Only perform downloading of data; Do not proceed to "
                     "session preparation, model definition or training")

flags.DEFINE_integer("worker_index", 0,
                     "Worker task index, should be >= 0. worker_index=0 is "
                     "the master worker task the performs the variable "
                     "initialization ")

flags.DEFINE_string("workers", None,
                    "The worker url list, separated by comma (e.g. tf-worker1:2222,1.2.3.4:2222)")

flags.DEFINE_string("parameter_servers", None,
                    "The ps url list, separated by comma (e.g. tf-ps2:2222,1.2.3.5:2222)")

flags.DEFINE_integer("grpc_port", 2222,
                     "TensorFlow GRPC port")

flags.DEFINE_integer("train_steps", 200000,
                     "Number of (global) training steps to perform")

flags.DEFINE_string("worker_grpc_url", None,
                    "Worker GRPC URL (e.g., grpc://1.2.3.4:2222, or "
                    "grpc://tf-worker0:2222)")
FLAGS = flags.FLAGS

cur_time = datetime.datetime.now().strftime('%Y%m%d%H%M%S')

def nn_layer(input_tensor, input_dim, output_dim, act=tf.nn.relu):
    with tf.name_scope(cur_time):
        weights = tf.Variable(tf.truncated_normal([input_dim, output_dim], stddev=0.1))
        biases = tf.Variable(tf.constant(0.1, shape=[output_dim]))
    activations = act(tf.matmul(input_tensor, weights) + biases)
    return activations

def model(x, y_, global_step):
    hidden_nodes = 500
    hidden1 = nn_layer(x, 784, hidden_nodes)
    y = nn_layer(hidden1, hidden_nodes, 10, act=tf.nn.softmax)

    cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
    train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy, global_step=global_step)

    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    return train_step, accuracy

print("Loading data from worker index = %d" % FLAGS.worker_index)

mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
print("Testing set size: %d" % len(mnist.test.images))
print("Training set size: %d" % len(mnist.train.images))
if FLAGS.download_only: sys.exit(0)

print("Worker GRPC URL: %s" % FLAGS.worker_grpc_url)
print("Workers = %s" % FLAGS.workers)
print("Using time = %s" % cur_time)

is_chief = (FLAGS.worker_index == 0)
cluster = tf.train.ClusterSpec({"ps": FLAGS.parameter_servers.split(","), "worker": FLAGS.workers.split(",")})
# Construct device setter object
device_setter = tf.train.replica_device_setter(cluster=cluster)

# The device setter will automatically place Variables ops on separate
# parameter servers (ps). The non-Variable ops will be placed on the workers.
with tf.device(device_setter):
    with tf.name_scope(cur_time):
        global_step = tf.Variable(0, trainable=False)

    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.float32, [None, 10])
    val_feed = {x: mnist.test.images, y_: mnist.test.labels}

    train_step, accuracy = model(x, y_, global_step)

    sv = tf.train.Supervisor(is_chief=is_chief,
                             logdir="/tmp/dist-mnist-log/train",
                             saver=tf.train.Saver(),
                             init_op=tf.initialize_all_variables(),
                             recovery_wait_secs=1,
                             global_step=global_step)
    sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True,
                                 device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.worker_index])

    # The chief worker (worker_index==0) session will prepare the session,
    # while the remaining workers will wait for the preparation to complete.
    if is_chief:
        print("Worker %d: Initializing session..." % FLAGS.worker_index)
    else:
        print("Worker %d: Waiting for session to be initialized..." % FLAGS.worker_index)

    with sv.prepare_or_wait_for_session(FLAGS.worker_grpc_url, config=sess_config) as sess:
        print("Worker %d: Session initialization complete." % FLAGS.worker_index)

        # Perform training
        time_begin = time.time()
        print("Training begins @ %f" % time_begin)

        local_step = 0
        while True:
            # Training feed
            batch_xs, batch_ys = mnist.train.next_batch(100)
            train_feed = {x: batch_xs, y_: batch_ys}

            _, step = sess.run([train_step, global_step], feed_dict=train_feed)
            local_step += 1
            if local_step % 100 == 0:
                print("Worker %d: training step %d done (global step: %d); Accuracy: %g" % 
                      (FLAGS.worker_index, local_step, step, sess.run(accuracy, feed_dict=val_feed)))
            if step >= FLAGS.train_steps: break

        time_end = time.time()
        print("Training ends @ %f" % time_end)
        training_time = time_end - time_begin
        print("Training elapsed time: %f s" % training_time)

        # Accuracy on test data
        print("Final test accuracy = %g" % (sess.run(accuracy, feed_dict=val_feed)))
perhapszzy commented 8 years ago

By the way, sometimes (not always) some worker may report:

E0625 12:21:16.511111376    3825 tcp_client_posix.c:191]     failed to connect to 'ipv4:180.101.191.78:30002': timeout occurred

But if it failed to connect to the server, why does the server log still show the output above? And for most of the workers (or most of the time), the code just gets stuck at the prepare_or_wait_for_session step and outputs nothing.
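
One minimal way to separate the connection problem from the Supervisor logic (a sketch, not something from the original report; the target list just reuses the NodePort addresses from the commands above) is to open a plain session against each worker's gRPC endpoint and run a trivial op:

import tensorflow as tf

# Hypothetical connectivity check: if this hangs or times out for an
# endpoint, the gRPC server there is unreachable from this machine, and
# prepare_or_wait_for_session against that endpoint would also appear to hang.
targets = ["grpc://180.101.191.78:30001",
           "grpc://180.101.191.78:30002",
           "grpc://180.101.191.78:30003"]

for target in targets:
    with tf.Session(target) as sess:
        print(target, sess.run(tf.constant("ok")))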

perhapszzy commented 8 years ago

Seems to be a duplicate of #2472.