cerndb / dist-keras

Distributed Deep Learning, with a focus on distributed training, using Keras and Apache Spark.
http://joerihermans.com/work/distributed-keras/
GNU General Public License v3.0

Tensorflow back end not working on GPUs (session lost) #44

Open johneortega opened 6 years ago

johneortega commented 6 years ago

When running the MNIST example on a GPU with the TensorFlow backend, the session seems to be lost while training the model. The first part of the code that uses a session executes fine, so I'm perplexed as to what could be going on here. It looks as if the TF session is being lost on the new worker. Any help would be greatly appreciated.

Traceback (most recent call last):
  File "mnist.py", line 268, in <module>
    trained_model = trainer.train(training_set)
  File "lib/python2.7/site-packages/distkeras/trainers.py", line 638, in train
    self.history = dataset.rdd.mapPartitionsWithIndex(worker.train).collect()
  File "spark/python/pyspark/rdd.py", line 808, in collect
    port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
  File "spark/python/lib/py4j-0.10.4-src.zip/py4j/java_gateway.py", line 1133, in __call__
  File "spark/python/pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "spark/python/lib/py4j-0.10.4-src.zip/py4j/protocol.py", line 319, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 19.0 failed 4 times, most recent failure: Lost task 1.3 in stage 19.0 (TID 60, 10.30.72.126, executor 0): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "spark/python/lib/pyspark.zip/pyspark/worker.py", line 174, in main
    process()
  File "spark/python/lib/pyspark.zip/pyspark/worker.py", line 169, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "lib/python2.7/site-packages/distkeras/workers.py", line 260, in train
    self.prepare_model()
  File "lib/python2.7/site-packages/distkeras/workers.py", line 97, in prepare_model
    self.model = deserialize_keras_model(self.model)
  File "lib/python2.7/site-packages/distkeras/utils.py", line 126, in deserialize_keras_model
    model.set_weights(weights)
  File "lib/python2.7/site-packages/keras/models.py", line 702, in set_weights
    self.model.set_weights(weights)
  File "lib/python2.7/site-packages/keras/engine/topology.py", line 2004, in set_weights
    K.batch_set_value(tuples)
  File "lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 2193, in batch_set_value
    get_session().run(assign_ops, feed_dict=feed_dict)
  File "lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 163, in get_session
    _SESSION = tf.Session(config=config)
  File "lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1486, in __init__
    super(Session, self).__init__(target, graph, config=config)
  File "lib/python2.7/site-packages/tensorflow/python/client/session.py", line 621, in __init__
    self._session = tf_session.TF_NewDeprecatedSession(opts, status)
  File "lib/python2.7/contextlib.py", line 24, in __exit__
    self.gen.next()
  File "lib/python2.7/site-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status)
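
The last frames of the traceback fail inside `get_session()`, at the point where `tf.Session(config=config)` is constructed on the executor. One way to narrow this down might be to check whether a bare session can be opened on each worker at all, independently of dist-keras. The following is only a minimal sketch under stated assumptions: it assumes the TensorFlow 1.x API (`tf.Session`), and the helper name `try_session` and the `sc` SparkContext in the usage comment are hypothetical, not part of dist-keras or the original script.

```python
def try_session(index, iterator):
    """Open a bare TensorFlow session inside one Spark partition and report the result.

    Hypothetical diagnostic helper, not part of dist-keras.
    Assumes the TensorFlow 1.x API (tf.Session).
    """
    # Import here so the module is resolved in the executor's Python process.
    import tensorflow as tf

    # The partition data (`iterator`) is intentionally ignored; only the
    # ability to create and use a session on this worker is being checked.
    try:
        with tf.Session() as sess:
            value = sess.run(tf.constant(2))
        result = (index, "ok", int(value))
    except Exception as exc:
        # Report the error instead of raising, so every partition answers.
        result = (index, "failed", repr(exc))
    yield result


# Usage on the cluster (assumes `sc` is an existing SparkContext):
# results = sc.parallelize(range(4), 4).mapPartitionsWithIndex(try_session).collect()
# for row in sorted(results):
#     print(row)
```

If some partitions report `"failed"` even with this bare session, that would suggest the problem lies in the executor's TensorFlow or GPU environment rather than in how dist-keras deserializes the model. If every partition reports `"ok"`, the dist-keras worker setup is the more likely place to look.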