intel-analytics / analytics-zoo

Distributed Tensorflow, Keras and PyTorch on Apache Spark/Flink & Ray
https://analytics-zoo.readthedocs.io/
Apache License 2.0

RNNs with stateful=True not yet supported with tf.distribute.Strategy. #520

Open leonardozcm opened 3 years ago

leonardozcm commented 3 years ago

tensorflow==2.3.0

When fitting an LSTM with stateful=True using Orca, model creation fails with the error below. It seems that distributed training of stateful RNNs is not yet supported by the underlying tf.distribute?

Traceback (most recent call last):
  File "/home/arda/PycharmProjects/forecasting_with_lstm/lstm_orca", line 154, in <module>
    train()
  File "/home/arda/PycharmProjects/forecasting_with_lstm/lstm_orca", line 105, in train
    est = Estimator.from_keras(model_creator=model_creator)
  File "/home/arda/Project/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_ray_estimator.py", line 191, in from_keras
    backend=backend, compile_args_creator=compile_args_creator)
  File "/home/arda/Project/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_ray_estimator.py", line 163, in __init__
    for i, worker in enumerate(self.remote_workers)])
  File "/home/arda/anaconda3/envs/tf2/lib/python3.6/site-packages/ray/worker.py", line 1513, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ValueError): ray::TFRunner.setup_distributed() (pid=32403, ip=10.239.166.112)
  File "python/ray/_raylet.pyx", line 452, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 407, in ray._raylet.execute_task.function_executor
  File "/home/arda/Project/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 315, in setup_distributed
    self.model = self.model_creator(self.config)
  File "/home/arda/PycharmProjects/forecasting_with_lstm/lstm_orca", line 90, in model_creator
    batch_input_shape=[1, None, 1]),
  File "/home/arda/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/layers/recurrent_v2.py", line 1082, in __init__
    **kwargs)
  File "/home/arda/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/layers/recurrent.py", line 1099, in __init__
    super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)
  File "/home/arda/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/layers/recurrent.py", line 2753, in __init__
    **kwargs)
  File "/home/arda/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/layers/recurrent.py", line 443, in __init__
    raise ValueError('RNNs with stateful=True not yet supported with '
ValueError: RNNs with stateful=True not yet supported with tf.distribute.Strategy.
Stopping orca context

Process finished with exit code 1
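The traceback above suggests the error is raised when the layer is constructed inside a distribution strategy scope, before any Orca-specific training code runs. A minimal sketch of a reproduction without Orca follows. It assumes tensorflow==2.3.0 as in the report; `MirroredStrategy` is only a stand-in for whatever strategy Orca sets up, and both function names are hypothetical helpers, not part of TensorFlow or Analytics Zoo. Behaviour on other TensorFlow versions may differ.

```python
def is_stateful_strategy_error(exc):
    """Return True if exc looks like the Keras 'stateful RNN under a strategy' error."""
    message = str(exc)
    return (
        isinstance(exc, ValueError)
        and "stateful" in message
        and "tf.distribute.Strategy" in message
    )


def build_stateful_lstm_under_strategy():
    """Try to construct a stateful LSTM inside a strategy scope.

    Returns the raised ValueError, or None if construction succeeded.
    """
    # Imported lazily so the helper above can be used without TensorFlow.
    import tensorflow as tf

    # Assumption: MirroredStrategy stands in for the strategy used by Orca.
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        try:
            # The layer constructor is where the traceback shows the failure.
            tf.keras.layers.LSTM(4, stateful=True)
        except ValueError as exc:
            return exc
    return None
```

If `is_stateful_strategy_error(build_stateful_lstm_under_strategy())` is true on a plain TensorFlow install, the limitation comes from TensorFlow itself rather than from Orca.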

cyita commented 3 years ago

@yangw1234 Please take a look.

yangw1234 commented 3 years ago

this is an issue of tensorflow. I'll close it.

jason-dai commented 3 years ago

> this is an issue of tensorflow. I'll close it.

@yangw1234 what's your suggestion then? using TF 1.5? Using Horovod?