yang-stressfree opened 2 years ago
Similar issue https://github.com/tensorflow/addons/issues/2672
Tested with:
tensorflow==2.8.0
tensorflow-addons==0.16.1
and got the same error 😢
I figured out why this error occurs when trying to save the model with tf.saved_model.save or model.save:
def call(self, inputs, training=None, mask=None):
    batch_seq_encoded, batch_seq_labeled = inputs
    # ...
    self.dec_attention.setup_memory(batch_seq_encoded)
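For context, here is a minimal sketch of the kind of decoder layer the snippet above belongs to (apart from dec_attention and setup_memory, the class, attribute, and variable names are my own placeholders, not the code from the issue):

```python
import tensorflow as tf
import tensorflow_addons as tfa


class TinyDemoDecoder(tf.keras.layers.Layer):
    """Hypothetical decoder wiring BahdanauAttention into an AttentionWrapper."""

    def __init__(self, units=64, **kwargs):
        super().__init__(**kwargs)
        self.dec_attention = tfa.seq2seq.BahdanauAttention(units=units)
        self.dec_cell = tfa.seq2seq.AttentionWrapper(
            tf.keras.layers.GRUCell(units), self.dec_attention
        )
        self.dec_rnn = tf.keras.layers.RNN(self.dec_cell, return_sequences=True)

    def call(self, inputs, training=None, mask=None):
        batch_seq_encoded, batch_seq_labeled = inputs
        # The attention memory is set from a per-batch (dynamic) tensor here;
        # this is the graph tensor tf.saved_model.save later fails to
        # constant-fold when it serializes the traced call function.
        self.dec_attention.setup_memory(batch_seq_encoded)
        return self.dec_rnn(batch_seq_labeled, training=training)
```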
Actually, the error explains itself:
ValueError: Unable to save function b'__inference_tiny_demo_model_layer_call_fn_6516' because it captures graph tensor Tensor("BahdanauAttention/strided_slice:0", shape=(), dtype=int32) from a parent function which cannot be converted to a constant with tf.get_static_value.
The reason is that tf.saved_model.save tries to convert batch_seq_encoded, a dynamic per-batch value used as the memory of the attention mechanism, into a static constant.
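To make that concrete, here is a small standalone check (not from the issue) showing what tf.get_static_value can and cannot fold:

```python
import tensorflow as tf

# A literal constant has a static value, so it can be folded at save time.
print(tf.get_static_value(tf.constant(3)))  # -> 3

@tf.function(input_signature=[tf.TensorSpec([None, 16], tf.float32)])
def encode(x):
    # A value derived from a dynamic dimension (like the batch size of the
    # encoder memory) is a graph tensor with no static value.
    print(tf.get_static_value(tf.shape(x)[0]))  # -> None
    return x

encode(tf.zeros([2, 16]))
```

The Tensor("BahdanauAttention/strided_slice:0") in the traceback is most likely exactly such a value, sliced from the dynamic shape of the memory passed to setup_memory.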
For anyone who wants to export a model with tf.saved_model.save: DO NOT implement the attention mechanism with tfa.seq2seq.BahdanauAttention and tfa.seq2seq.AttentionWrapper.
For a beginner, the attention implementation in https://www.tensorflow.org/text/tutorials/nmt_with_attention is a workable and portable alternative.
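Roughly, the tutorial's approach is a plain Keras layer that receives the encoder memory as a call() argument instead of storing it with setup_memory(), so no batch-dependent tensor gets captured at export time. A simplified sketch of that pattern (not the tutorial's exact code):

```python
import tensorflow as tf


class BahdanauStyleAttention(tf.keras.layers.Layer):
    """Additive (Bahdanau-style) attention; the memory is passed on every call."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        # query: [batch, units], values (encoder outputs): [batch, time, units]
        query_with_time_axis = tf.expand_dims(query, 1)
        score = self.V(tf.nn.tanh(self.W1(query_with_time_axis) + self.W2(values)))
        attention_weights = tf.nn.softmax(score, axis=1)            # [batch, time, 1]
        context_vector = tf.reduce_sum(attention_weights * values, axis=1)
        return context_vector, attention_weights
```

Because the decoder hands the encoder outputs to this layer explicitly on every step, the exported SavedModel only needs to capture trainable weights, not a per-batch tensor.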
System information
Describe the bug
If my decoder model includes tfa.seq2seq.BahdanauAttention and tfa.seq2seq.AttentionWrapper layers, the model fails to save; if I remove those layers from the model, it saves successfully.
Code to reproduce the issue
I created a Colab notebook that reproduces this issue:
https://colab.research.google.com/gist/yang-stressfree/bc78b67ca6f051fe60a7e863b99cc1b3#scrollTo=0jcDbzxVD4-h
Running it produces the ValueError quoted above.