tensorflow / addons

Useful extra functionality for TensorFlow 2.x maintained by SIG-addons
Apache License 2.0

Errors when using tf.keras.Input & tfa.seq2seq in eager mode #1898

Open hccho2 opened 4 years ago

hccho2 commented 4 years ago

System information

```python
import tensorflow as tf
import tensorflow_addons as tfa

batch_size = 2

encoder_hidden_dim = 6
encoder_output = tf.keras.Input(shape=[None, encoder_hidden_dim], batch_size=batch_size)  # Input
encoder_seq_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32)  # Input

decoder_vocab_size = 10
decoder_embedding_dim = 8
decoder_hidden_dim = 5
attention_units = 11
output_dim = 5

decoder_cell = tf.keras.layers.LSTMCell(decoder_hidden_dim)
decoder_init_state = tuple(decoder_cell.get_initial_state(inputs=None, batch_size=batch_size, dtype=tf.float32))
attention_mechanism = tfa.seq2seq.BahdanauAttention(attention_units, encoder_output, memory_sequence_length=encoder_seq_length)
attention_wrapper_cell = tfa.seq2seq.AttentionWrapper(
    decoder_cell, attention_mechanism, attention_layer_size=13,
    initial_cell_state=decoder_init_state, output_attention=True, alignment_history=False)
projection_layer = tf.keras.layers.Dense(output_dim)

attention_init_state = attention_wrapper_cell.get_initial_state(inputs=None, batch_size=batch_size, dtype=tf.float32)

# Rebuild the state with the LSTM cell state as a list instead of a tuple;
# the other AttentionWrapperState fields are passed through unchanged
attention_init_state = tfa.seq2seq.AttentionWrapperState(
    list(attention_init_state.cell_state), attention_init_state.attention, attention_init_state.alignments,
    attention_init_state.alignment_history, attention_init_state.attention_state)

sampler = tfa.seq2seq.sampler.TrainingSampler()
decoder = tfa.seq2seq.BasicDecoder(attention_wrapper_cell, sampler, output_layer=projection_layer)

decoder_inputs = tf.keras.Input(shape=[None], batch_size=batch_size, dtype=tf.int32)  # Input
decoder_seq_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32)  # Input

decoder_embedding = tf.keras.layers.Embedding(decoder_vocab_size, decoder_embedding_dim, trainable=True)
decoder_embedded = decoder_embedding(decoder_inputs)

outputs, last_state, last_sequence_lengths = decoder(
    decoder_embedded, initial_state=attention_init_state, sequence_length=decoder_seq_length, training=True)

my_model = tf.keras.Model(
    [decoder_inputs, decoder_seq_length, encoder_output, encoder_seq_length],
    [outputs, last_state, last_sequence_lengths])

### Test
encoder_timestep = 10
decoder_timestep = 12

encoder_output_data = tf.random.normal(shape=(batch_size, encoder_timestep, encoder_hidden_dim))
encoder_seq_length_data = tf.convert_to_tensor([encoder_timestep] * batch_size, dtype=tf.int32)

decoder_inputs_data = tf.random.uniform([batch_size, decoder_timestep], 0, decoder_vocab_size, tf.int32)
decoder_seq_length_data = tf.convert_to_tensor([decoder_timestep] * batch_size, dtype=tf.int32)

a, b, c = my_model([decoder_inputs_data, decoder_seq_length_data, encoder_output_data, encoder_seq_length_data])  # errors!
```

Q1. What is the cause of the error?

If I set memory_sequence_length=None, there is no error:

```python
attention_mechanism = tfa.seq2seq.BahdanauAttention(attention_units, encoder_output, memory_sequence_length=None)
```

Q2. a, b and c are non-numeric tensors. Why do they not contain numerical values?

failure-to-thrive commented 4 years ago

Due to cross-interactions between graph tensors and Python control flow deep inside TensorFlow, a quick workaround is to switch back to graph mode with tf.compat.v1.disable_eager_execution().
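A minimal sketch of what this workaround implies, using plain TensorFlow ops rather than the seq2seq model above (the tensors and values are illustrative only). According to the TensorFlow documentation, `tf.compat.v1.disable_eager_execution()` can only be called before any graphs, ops, or tensors are created, so it belongs at the very top of the script. After the switch, tensors are symbolic and only receive numeric values when the graph is run, for example in a `tf.compat.v1.Session`.

```python
import tensorflow as tf

# Process-wide switch to TF1-style graph mode; call it before creating
# any layer, model, op or tensor.
tf.compat.v1.disable_eager_execution()

x = tf.constant([1.0, 2.0, 3.0])
y = x * 2.0  # a symbolic graph tensor: it has no value yet

# Values only exist once the graph is executed in a session.
with tf.compat.v1.Session() as sess:
    values = sess.run(y)

print(values)  # [2. 4. 6.]
```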

hccho2 commented 4 years ago

@failure-to-thrive Are these cross-interactions caused by my code, or is it a bug?

failure-to-thrive commented 4 years ago

@hccho2 It doesn't seem to be your fault.

bhack commented 4 years ago

class BahdanauAttention(_BaseAttentionMechanism):

And in _BaseAttentionMechanism:

Also note that this layer does not work with Keras model when model.compile(run_eagerly=True) due to the fact that this layer is stateful. The support for that will be added in a future version.

/cc @qlzh727

bhack commented 4 years ago

I don't know whether your case could be resolved with the manual memory reset introduced in the PR that fixed https://github.com/tensorflow/addons/issues/535

guillaumekln commented 4 years ago

This is a known issue when the Keras functional API, stateful layers, and eager mode are used at the same time.

The Keras functional API creates symbolic tensors that are saved in stateful layers, but the layer is executed in eager mode where the TensorFlow runtime does not expect to find symbolic tensors.
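For contrast, here is a minimal sketch with core Keras's `tf.keras.layers.AdditiveAttention`, a Bahdanau-style layer swapped in for illustration. It is not part of tfa.seq2seq and performs no decoding. This layer receives the encoder outputs as a call argument instead of storing them, so the functional model keeps no symbolic tensor inside a layer and can be called eagerly with concrete data. The `query` input stands in for decoder states, and sequence-length masking is omitted for brevity.

```python
import tensorflow as tf

batch_size, enc_len, dec_len, dim = 2, 10, 12, 6

# Symbolic inputs of the functional model.
query = tf.keras.Input(shape=[None, dim])   # stand-in for decoder states
memory = tf.keras.Input(shape=[None, dim])  # encoder outputs

# The memory flows through the call arguments; the layer stores only its weights.
context = tf.keras.layers.AdditiveAttention()([query, memory])
model = tf.keras.Model([query, memory], context)

# An eager call with concrete data returns a tensor with real values.
out = model([tf.random.normal([batch_size, dec_len, dim]),
             tf.random.normal([batch_size, enc_len, dim])])
print(out.shape)  # (2, 12, 6)
```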

jkquijas commented 3 years ago

Has this been resolved?

guillaumekln commented 3 years ago

No, the issue is still open.