lifeomic / sparkflow

Easy to use library to bring Tensorflow on Apache Spark
MIT License

Incompatibility of build_graph function with tensorflow 2.0 #37

Open mattdornfeld opened 4 years ago

mattdornfeld commented 4 years ago

I'm running the example below with TF 2.0 and get the error: AttributeError: module 'tensorflow_core._api.v2.train' has no attribute 'export_meta_graph'. It seems the build_graph function is incompatible with the TF 2.0 API. The same example works fine with TF 1.15.

In [2]: import tensorflow as tf
   ...: from tensorflow import keras
   ...: from tensorflow.keras import layers
   ...: from sparkflow.graph_utils import build_graph
   ...:
   ...: tf.compat.v1.disable_eager_execution()
   ...:
   ...: output_dim = 64
   ...: model = keras.Sequential()
   ...: model.add(layers.Dense(output_dim, kernel_initializer='uniform', input_shape=(10,)))
   ...: model.add(layers.Activation('softmax'))
   ...:
   ...: loss_fn = keras.losses.SparseCategoricalCrossentropy()
   ...: model.compile(loss=loss_fn, optimizer='adam')
   ...:
   ...: y_true = tf.compat.v1.placeholder(dtype=tf.float32, shape=(None, output_dim))
   ...: loss = model.loss.fn(y_true, model.output)
   ...: mg = build_graph(lambda : loss)
   ...:
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-2-33c7a76593c9> in <module>
     16 y_true = tf.compat.v1.placeholder(dtype=tf.float32, shape=(None, output_dim))
     17 loss = model.loss.fn(y_true, model.output)
---> 18 mg = build_graph(lambda : loss)

/usr/local/lib/python3.7/site-packages/sparkflow/graph_utils.py in build_graph(func)
     12     with first_graph.as_default() as g:
     13         v = func()
---> 14         mg = json_format.MessageToJson(tf.train.export_meta_graph())
     15     return mg
     16

AttributeError: module 'tensorflow_core._api.v2.train' has no attribute 'export_meta_graph'

M0315G commented 3 years ago

Yes, I experienced the same issue when training my CNN classifier. The reason is that, starting with TF 2.x, TensorFlow runs in eager mode by default and no longer relies heavily on static graphs (DAGs) executed in a session. The only solution I could find is to downgrade to TF 1.x and use its API.
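A possible alternative to downgrading, offered only as a sketch: TF 2.x still ships the graph-mode function under tf.compat.v1.train.export_meta_graph, so the lookup that fails in the traceback could fall back to that location. The helper name resolve_export_meta_graph and the fake_tf1 / fake_tf2 stand-ins below are hypothetical and are not part of sparkflow; the stand-ins only mimic the two API layouts so the example runs without TensorFlow installed.

```python
from types import SimpleNamespace


def resolve_export_meta_graph(tf_module):
    """Return an export_meta_graph callable from a TensorFlow-like module.

    Hypothetical helper (not part of sparkflow): tries the TF 1.x location
    first, then the TF 2.x compatibility location.
    """
    candidates = (
        ("train",),                 # TF 1.x: tf.train.export_meta_graph
        ("compat", "v1", "train"),  # TF 2.x: tf.compat.v1.train.export_meta_graph
    )
    for path in candidates:
        namespace = tf_module
        for name in path:
            namespace = getattr(namespace, name, None)
            if namespace is None:
                break
        exporter = getattr(namespace, "export_meta_graph", None)
        if callable(exporter):
            return exporter
    raise AttributeError(
        "export_meta_graph not found in tf.train or tf.compat.v1.train"
    )


# Stand-ins for the two API layouts (assumptions for illustration only).
fake_tf1 = SimpleNamespace(
    train=SimpleNamespace(export_meta_graph=lambda: "tf1 meta graph")
)
fake_tf2 = SimpleNamespace(
    train=SimpleNamespace(),  # no export_meta_graph here, as in the traceback
    compat=SimpleNamespace(
        v1=SimpleNamespace(
            train=SimpleNamespace(
                export_meta_graph=lambda: "tf2 compat meta graph"
            )
        )
    ),
)

print(resolve_export_meta_graph(fake_tf1)())
print(resolve_export_meta_graph(fake_tf2)())
```

With a real installation, the same idea would mean calling resolve_export_meta_graph(tf)() inside the first_graph.as_default() block of build_graph instead of tf.train.export_meta_graph(). Whether the rest of sparkflow then works on TF 2.x is not confirmed anywhere in this thread.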