keras-team / keras-io

Keras documentation, hosted live at keras.io
Apache License 2.0

Would it be possible to make lstm_seq2seq support mixed precision? #1861

Status: Open. joshuayao opened this issue 4 months ago.

joshuayao commented 4 months ago

Issue Type

Bug

Source

binary

Keras Version

2.16.0

Custom Code

No

OS Platform and Distribution

No response

Python version

3.11

GPU model and memory

No response

Current Behavior?

lstm_seq2seq.py runs correctly with the default fp32 data type when using legacy Keras, enabled with:

import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"

After enabling mixed precision as shown below, training completes successfully, but inference fails with: Input 'y' of 'AddV2' Op has type float32 that does not match type bfloat16 of argument 'x'.

import tensorflow as tf
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")

Standalone code to reproduce the issue or tutorial link

Add the following snippet at the beginning of this example: https://github.com/keras-team/keras-io/blob/master/examples/nlp/lstm_seq2seq.py

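import os
# Set before importing TensorFlow so that tf.keras resolves to legacy Keras (tf_keras).
os.environ["TF_USE_LEGACY_KERAS"] = "1"
import tensorflow as tf
# Compute in bfloat16 while keeping variables in float32.
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")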

Relevant log output

Traceback (most recent call last):
  File "examples/nlp/lstm_seq2seq.py", line 332, in <module>
    decoded_sentence = decode_sequence(input_seq)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "examples/nlp/lstm_seq2seq.py", line 297, in decode_sequence
    output_tokens, h, c = decoder_model.predict([target_seq] + states_value)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/tmp/__autograph_generated_filehtao8ahn.py", line 15, in tf__predict_function
    retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
    ^^^^^
  File "/tmp/__autograph_generated_fileg32txcku.py", line 45, in tf__step_function
    outputs = ag__.converted_call(ag__.ld(model).distribute_strategy.run, (ag__.ld(run_step),), dict(args=(ag__.ld(data),)), fscope)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_fileg32txcku.py", line 18, in run_step
    outputs = ag__.converted_call(ag__.ld(model).predict_step, (ag__.ld(data),), None, fscope_1)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file8lg3jru0.py", line 32, in tf__predict_step
    retval_ = ag__.converted_call(ag__.ld(self), (ag__.ld(x),), dict(training=False), fscope)
    ^^^^^
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 44, in tf__error_handler
    ag__.if_stmt(ag__.not_(ag__.converted_call(ag__.ld(tf).debugging.is_traceback_filtering_enabled, (), None, fscope)), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 40, in else_body
    raise ag__.converted_call(ag__.ld(e).with_traceback, (ag__.ld(filtered_tb),), None, fscope) from None
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 34, in else_body
    retval_ = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_filetingiv6p.py", line 67, in tf____call__
    retval_ = ag__.converted_call(ag__.converted_call(ag__.ld(super), (), None, fscope).__call__, tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
    ^^^^^
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 44, in tf__error_handler
    ag__.if_stmt(ag__.not_(ag__.converted_call(ag__.ld(tf).debugging.is_traceback_filtering_enabled, (), None, fscope)), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 40, in else_body
    raise ag__.converted_call(ag__.ld(e).with_traceback, (ag__.ld(filtered_tb),), None, fscope) from None
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 34, in else_body
    retval_ = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_filedgxei06x.py", line 242, in tf____call__
    ag__.if_stmt(ag__.converted_call(ag__.ld(_in_functional_construction_mode), (ag__.ld(self), ag__.ld(inputs), ag__.ld(args), ag__.ld(kwargs), ag__.ld(input_list)), None, fscope), if_body_11, else_body_11, get_state_11, set_state_11, ('do_return', "kwargs['mask']", 'retval_', 'args', 'input_list', 'inputs', 'kwargs'), 3)
  File "/tmp/__autograph_generated_filedgxei06x.py", line 187, in else_body_11
    outputs = ag__.converted_call(ag__.ld(call_fn), (ag__.ld(inputs),) + tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file_fd34cvd.py", line 51, in error_handler
    ag__.if_stmt(ag__.converted_call(ag__.ld(hasattr), (ag__.ld(e), '_keras_call_info_injected'), None, fscope_1), if_body, else_body, get_state, set_state, (), 0)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file_fd34cvd.py", line 47, in if_body
    raise ag__.ld(e)
  File "/tmp/__autograph_generated_file_fd34cvd.py", line 34, in error_handler
    retval__1 = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_fileqmtratyp.py", line 29, in tf__call
    retval_ = ag__.converted_call(ag__.ld(self)._run_internal_graph, (ag__.ld(inputs),), dict(training=ag__.ld(training), mask=ag__.ld(mask)), fscope)
    ^^^^^
  File "/tmp/__autograph_generated_file0lcomwvu.py", line 174, in tf___run_internal_graph
    ag__.for_stmt(ag__.ld(depth_keys), None, loop_body_4, get_state_9, set_state_9, (), {'iterate_names': 'depth'})
  File "/tmp/__autograph_generated_file0lcomwvu.py", line 166, in loop_body_4
    ag__.for_stmt(ag__.ld(nodes), None, loop_body_3, get_state_8, set_state_8, (), {'iterate_names': 'node'})
  File "/tmp/__autograph_generated_file0lcomwvu.py", line 165, in loop_body_3
    ag__.if_stmt(ag__.not_(continue__3), if_body_4, else_body_4, get_state_7, set_state_7, ('continue__3',), 0)
  File "/tmp/__autograph_generated_file0lcomwvu.py", line 160, in if_body_4
    ag__.if_stmt(ag__.not_(continue__3), if_body_3, else_body_3, get_state_6, set_state_6, (), 0)
  File "/tmp/__autograph_generated_file0lcomwvu.py", line 145, in if_body_3
    outputs = ag__.converted_call(ag__.ld(node).layer, tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file9p2pb1je.py", line 184, in tf____call__
    ag__.if_stmt(ag__.and_(lambda: ag__.ld(initial_state) is None, lambda: ag__.ld(constants) is None), if_body_7, else_body_7, get_state_8, set_state_8, ('do_return', "kwargs['constants']", "kwargs['initial_state']", 'retval_', 'self._num_constants', 'self.constants_spec', 'self.input_spec', 'self.state_spec'), 8)
  File "/tmp/__autograph_generated_file9p2pb1je.py", line 175, in else_body_7
    ag__.if_stmt(ag__.ld(is_keras_tensor), if_body_6, else_body_6, get_state_7, set_state_7, ('do_return', "kwargs['constants']", "kwargs['initial_state']", 'retval_', 'self.input_spec'), 5)
  File "/tmp/__autograph_generated_file9p2pb1je.py", line 168, in else_body_6
    retval_ = ag__.converted_call(ag__.converted_call(ag__.ld(super), (), None, fscope).__call__, (ag__.ld(inputs),), dict(**ag__.ld(kwargs)), fscope)
    ^^^^^
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 44, in tf__error_handler
    ag__.if_stmt(ag__.not_(ag__.converted_call(ag__.ld(tf).debugging.is_traceback_filtering_enabled, (), None, fscope)), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 40, in else_body
    raise ag__.converted_call(ag__.ld(e).with_traceback, (ag__.ld(filtered_tb),), None, fscope) from None
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 34, in else_body
    retval_ = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_filedgxei06x.py", line 242, in tf____call__
    ag__.if_stmt(ag__.converted_call(ag__.ld(_in_functional_construction_mode), (ag__.ld(self), ag__.ld(inputs), ag__.ld(args), ag__.ld(kwargs), ag__.ld(input_list)), None, fscope), if_body_11, else_body_11, get_state_11, set_state_11, ('do_return', "kwargs['mask']", 'retval_', 'args', 'input_list', 'inputs', 'kwargs'), 3)
  File "/tmp/__autograph_generated_filedgxei06x.py", line 187, in else_body_11
    outputs = ag__.converted_call(ag__.ld(call_fn), (ag__.ld(inputs),) + tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file_fd34cvd.py", line 162, in error_handler
    raise ag__.converted_call(ag__.ld(new_e).with_traceback, (ag__.ld(e).__traceback__,), None, fscope_1) from None
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file_fd34cvd.py", line 34, in error_handler
    retval__1 = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_filew_gt51m7.py", line 169, in tf__call
    ag__.if_stmt(ag__.not_(ag__.ld(self)._could_use_gpu_kernel), if_body_5, else_body_5, get_state_5, set_state_5, ('kwargs', 'last_output', 'outputs', 'runtime', 'states', 'inputs'), 5)
  File "/tmp/__autograph_generated_filew_gt51m7.py", line 153, in else_body_5
    ag__.if_stmt(ag__.converted_call(ag__.ld(gru_lstm_utils).use_new_gru_lstm_impl, (), None, fscope), if_body_4, else_body_4, get_state_4, set_state_4, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
  File "/tmp/__autograph_generated_filew_gt51m7.py", line 142, in else_body_4
    ag__.if_stmt(ag__.converted_call(ag__.ld(tf).executing_eagerly, (), None, fscope), if_body_3, else_body_3, get_state_3, set_state_3, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
  File "/tmp/__autograph_generated_filew_gt51m7.py", line 134, in else_body_3
    last_output, outputs, new_h, new_c, runtime = ag__.converted_call(ag__.ld(lstm_with_backend_selection), (), dict(**ag__.ld(normal_lstm_kwargs)), fscope)
                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_filecam0bgs0.py", line 118, in tf__lstm_with_backend_selection
    ag__.if_stmt(ag__.converted_call(ag__.ld(gru_lstm_utils).use_new_gru_lstm_impl, (), None, fscope), if_body, else_body, get_state, set_state, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
  File "/tmp/__autograph_generated_filecam0bgs0.py", line 107, in else_body
    last_output, outputs, new_h, new_c, runtime = ag__.converted_call(ag__.ld(defun_standard_lstm), (), dict(**ag__.ld(params)), fscope)
                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: in user code:

    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 2436, in predict_function  *
        return step_function(self, iterator)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 2409, in run_step  *
        outputs = model.predict_step(data)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 2377, in predict_step  *
        return self(x, training=False)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 565, in error_handler  *
        del filtered_tb
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 588, in __call__  *
        return super().__call__(*args, **kwargs)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 565, in error_handler  *
        del filtered_tb
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/base_layer.py", line 1136, in __call__  *
        outputs = call_fn(inputs, *args, **kwargs)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/functional.py", line 514, in call  *
        return self._run_internal_graph(inputs, training=training, mask=mask)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/functional.py", line 671, in _run_internal_graph  *
        outputs = node.layer(*args, **kwargs)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/layers/rnn/base_rnn.py", line 627, in __call__  *
        return super().__call__(inputs, **kwargs)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 560, in error_handler  *
        filtered_tb = _process_traceback_frames(e.__traceback__)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/base_layer.py", line 1136, in __call__  *
        outputs = call_fn(inputs, *args, **kwargs)
    File "/tmp/__autograph_generated_file_fd34cvd.py", line 162, in error_handler  **
        raise ag__.converted_call(ag__.ld(new_e).with_traceback, (ag__.ld(e).__traceback__,), None, fscope_1) from None
    File "/tmp/__autograph_generated_file_fd34cvd.py", line 34, in error_handler
        retval__1 = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope_1)
    File "/tmp/__autograph_generated_filew_gt51m7.py", line 169, in tf__call  **
        ag__.if_stmt(ag__.not_(ag__.ld(self)._could_use_gpu_kernel), if_body_5, else_body_5, get_state_5, set_state_5, ('kwargs', 'last_output', 'outputs', 'runtime', 'states', 'inputs'), 5)
    File "/tmp/__autograph_generated_filew_gt51m7.py", line 153, in else_body_5
        ag__.if_stmt(ag__.converted_call(ag__.ld(gru_lstm_utils).use_new_gru_lstm_impl, (), None, fscope), if_body_4, else_body_4, get_state_4, set_state_4, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
    File "/tmp/__autograph_generated_filew_gt51m7.py", line 142, in else_body_4
        ag__.if_stmt(ag__.converted_call(ag__.ld(tf).executing_eagerly, (), None, fscope), if_body_3, else_body_3, get_state_3, set_state_3, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
    File "/tmp/__autograph_generated_filew_gt51m7.py", line 134, in else_body_3
        last_output, outputs, new_h, new_c, runtime = ag__.converted_call(ag__.ld(lstm_with_backend_selection), (), dict(**ag__.ld(normal_lstm_kwargs)), fscope)
    File "/tmp/__autograph_generated_filecam0bgs0.py", line 118, in tf__lstm_with_backend_selection  **
        ag__.if_stmt(ag__.converted_call(ag__.ld(gru_lstm_utils).use_new_gru_lstm_impl, (), None, fscope), if_body, else_body, get_state, set_state, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
    File "/tmp/__autograph_generated_filecam0bgs0.py", line 107, in else_body
        last_output, outputs, new_h, new_c, runtime = ag__.converted_call(ag__.ld(defun_standard_lstm), (), dict(**ag__.ld(params)), fscope)
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/layers/rnn/lstm.py", line 983, in standard_lstm
        last_output, outputs, new_states = backend.rnn(
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/backend.py", line 4985, in rnn
        output_time_zero, _ = step_function(
    File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/layers/rnn/lstm.py", line 970, in step
        z += backend.dot(h_tm1, recurrent_kernel)

    TypeError: Exception encountered when calling layer 'lstm_1' (type LSTM).

    Input 'y' of 'AddV2' Op has type float32 that does not match type bfloat16 of argument 'x'.

    Call arguments received by layer 'lstm_1' (type LSTM):
      • inputs=tf.Tensor(shape=(None, 1, 91), dtype=bfloat16)
      • mask=None
      • training=False
      • initial_state=['tf.Tensor(shape=(None, 256), dtype=float32)', 'tf.Tensor(shape=(None, 256), dtype=float32)']
grasskin commented 4 months ago

Hi @joshuayao, just to confirm: this does not affect current Keras (Keras 3)? Could you try to reproduce it without the legacy flag, but still with mixed precision enabled?
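For reference, a minimal sketch of the Keras 3 counterpart to the reproduction snippet: drop TF_USE_LEGACY_KERAS and set the policy through Keras 3 itself. Choosing the TensorFlow backend via KERAS_BACKEND is an assumption made here to match the original report.

```python
import os

# Keras 3 setup without TF_USE_LEGACY_KERAS; TensorFlow backend assumed.
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras

keras.mixed_precision.set_global_policy("mixed_bfloat16")

# Sanity check: layers compute in bfloat16 while variables stay in float32.
policy = keras.mixed_precision.global_policy()
print(policy.name, policy.compute_dtype, policy.variable_dtype)
```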