tensorflow / addons

Useful extra functionality for TensorFlow 2.x maintained by SIG-addons

Failed to model.save when use tfa.seq2seq.BahdanauAttention and tfa.seq2seq.AttentionWrapper #2675

Open yang-stressfree opened 2 years ago

yang-stressfree commented 2 years ago

System information

Describe the bug

If my decoder model includes tfa.seq2seq.BahdanauAttention and tfa.seq2seq.AttentionWrapper layers, the model fails to save.

Otherwise, if I remove the tfa.seq2seq.BahdanauAttention and tfa.seq2seq.AttentionWrapper layers from the model, it saves successfully.

Code to reproduce the issue

I created a Colab notebook to reproduce this issue:

https://colab.research.google.com/gist/yang-stressfree/bc78b67ca6f051fe60a7e863b99cc1b3#scrollTo=0jcDbzxVD4-h

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow.keras as keras

class TinyDemoModel(keras.Model):
    def __init__(self, rnn_units, *args, **kwargs):
        super(TinyDemoModel, self).__init__(*args, **kwargs)
        self.vocab_size = 10
        self.rnn_units = rnn_units
        self.dec_embedding = keras.layers.Embedding(input_dim=self.vocab_size, output_dim=self.vocab_size)
        self.dec_lstm_cell = keras.layers.LSTMCell(units=self.rnn_units)
        self.dec_attention = tfa.seq2seq.BahdanauAttention(units=self.rnn_units)
        self.dec_rnn_cell = tfa.seq2seq.AttentionWrapper(cell=self.dec_lstm_cell,
                                                         attention_mechanism=self.dec_attention,
                                                         attention_layer_size=self.rnn_units)
        self.dec_fc = keras.layers.Dense(self.vocab_size)
        self.dec_train = tfa.seq2seq.BasicDecoder(self.dec_rnn_cell, tfa.seq2seq.sampler.TrainingSampler(),
                                                  output_layer=self.dec_fc)

    def dec_build_initial_state(self, batch_size, enc_h, enc_c):
        initial_state = self.dec_rnn_cell.get_initial_state(batch_size=batch_size, dtype=enc_h.dtype)
        initial_state = initial_state.clone(cell_state=[enc_h, enc_c])
        return initial_state

    def get_config(self):
        raise NotImplementedError

    def call(self, inputs, training=None, mask=None):
        batch_seq_encoded, batch_seq_labeled = inputs
        shape_batch_seq_labeled = tf.shape(batch_seq_labeled)
        batch_size = shape_batch_seq_labeled[0]
        # set the attention memory from the (dynamic) encoder output inside call()
        self.dec_attention.setup_memory(batch_seq_encoded)
        initial_state = self.dec_build_initial_state(batch_size, tf.zeros([batch_size, self.rnn_units]),
                                                     tf.zeros([batch_size, self.rnn_units]))
        batch_seq_labeled_embedded = self.dec_embedding(batch_seq_labeled)
        output, _, _ = self.dec_train(batch_seq_labeled_embedded, initial_state=initial_state)
        # pad and return
        batch_seq_predicted_odds = output.rnn_output
        pad_size = shape_batch_seq_labeled[1] - tf.shape(batch_seq_predicted_odds)[1]
        batch_seq_predicted_odds = tf.pad(batch_seq_predicted_odds, [[0, 0], [0, pad_size], [0, 0]])
        return batch_seq_predicted_odds

def train_and_save():
    rnn_units = 8
    batch_size = 2
    batch_seq_encoded = tf.ones([batch_size, 100, rnn_units], dtype=tf.float32)
    batch_seq_labeled = tf.ones([batch_size, 100], dtype=tf.int32)
    model = TinyDemoModel(rnn_units)
    loss_obj = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer="adam", loss=loss_obj)
    model.fit(x=(batch_seq_encoded, batch_seq_labeled), y=batch_seq_labeled)
    model.save(filepath="/tmp/saved_model_tiny_demo_model")

if __name__ == "__main__":
    train_and_save()

Running this file gives:

Traceback (most recent call last):
  File "tiny_model.py", line 61, in <module>
    train_and_save()
  File "tiny_model.py", line 57, in train_and_save
    model.save(filepath="/tmp/saved_model_tiny_demo_model")
  File "lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "lib/python3.9/site-packages/tensorflow/python/saved_model/save.py", line 402, in map_resources
    raise ValueError(
ValueError: Unable to save function b'__inference_tiny_demo_model_layer_call_fn_6516' because it captures graph tensor Tensor("BahdanauAttention/strided_slice:0", shape=(), dtype=int32) from a parent function which cannot be converted to a constant with `tf.get_static_value`.

Other info / logs


yang-stressfree commented 2 years ago

Similar issue https://github.com/tensorflow/addons/issues/2672

yang-stressfree commented 2 years ago

Tested with:

tensorflow==2.8.0
tensorflow-addons==0.16.1

and got the same error 😢

yang-stressfree commented 2 years ago

I figured out why this error occurs when trying to save the model with tf.saved_model.save or model.save:

    def call(self, inputs, training=None, mask=None):
        batch_seq_encoded, batch_seq_labeled = inputs
        # ...
        self.dec_attention.setup_memory(batch_seq_encoded)

Actually, the error explains itself:

ValueError: Unable to save function b'__inference_tiny_demo_model_layer_call_fn_6516' because it captures graph tensor Tensor("BahdanauAttention/strided_slice:0", shape=(), dtype=int32) from a parent function which cannot be converted to a constant with tf.get_static_value.

tf.saved_model.save tries to convert batch_seq_encoded, a dynamic tensor that was captured as the attention mechanism's memory via setup_memory inside call(), into a static constant, and fails because its value is only known at call time.
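To see what the error message means, here is a tiny standalone illustration (not part of the original code) of tf.get_static_value: it can only fold tensors whose value is known at trace time, which is exactly what the saver needs for captured tensors:

import tensorflow as tf

# A value known at trace time can be folded into a constant by the saver:
print(tf.get_static_value(tf.constant(3)))  # 3

@tf.function(input_signature=[tf.TensorSpec([None, 100, 8], tf.float32)])
def f(x):
    # With an unknown batch dimension, this shape slice is a symbolic graph
    # tensor (compare "BahdanauAttention/strided_slice:0" in the traceback),
    # so there is no static value to freeze:
    print(tf.get_static_value(tf.shape(x)[0]))  # None
    return x

f(tf.ones([2, 100, 8]))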

For anyone who wants to export a model with tf.saved_model.save: DO NOT implement the attention mechanism with tfa.seq2seq.BahdanauAttention and tfa.seq2seq.AttentionWrapper.

For a beginner, https://www.tensorflow.org/text/tutorials/nmt_with_attention is a workable and portable alternative.
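Below is a minimal sketch (my own simplification, not code from that tutorial) of the pattern it uses: the encoder output is passed through call() as a regular input and attention is computed with the stock keras.layers.AdditiveAttention (Bahdanau-style) layer, so no dynamic tensor is stashed on a layer between calls:

import tensorflow as tf
import tensorflow.keras as keras

class BahdanauStyleDecoder(keras.Model):
    def __init__(self, vocab_size=10, rnn_units=8, **kwargs):
        super().__init__(**kwargs)
        self.embedding = keras.layers.Embedding(vocab_size, rnn_units)
        self.rnn = keras.layers.LSTM(rnn_units, return_sequences=True)
        # AdditiveAttention is Keras' built-in Bahdanau-style attention
        self.attention = keras.layers.AdditiveAttention()
        self.fc = keras.layers.Dense(vocab_size)

    def call(self, inputs, training=None, mask=None):
        enc_output, dec_tokens = inputs                   # memory arrives as an input
        dec_rnn = self.rnn(self.embedding(dec_tokens))    # [batch, T, units]
        context = self.attention([dec_rnn, enc_output])   # query, value
        return self.fc(tf.concat([dec_rnn, context], axis=-1))

model = BahdanauStyleDecoder()
enc = tf.ones([2, 100, 8])
tokens = tf.ones([2, 100], tf.int32)
model.compile(optimizer="adam",
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True))
model.fit(x=(enc, tokens), y=tokens, verbose=0)
# nothing is captured across calls, so the SavedModel export should not hit the error
model.save("/tmp/saved_model_additive_attention_demo")

Here the attention memory never leaves the call graph, so there is no captured graph tensor for the saver to convert into a constant at export time.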