keithito / tacotron

A TensorFlow implementation of Google's Tacotron speech synthesis with pre-trained model (unofficial)

How can the output dim of ConcatOutputAndAttentionWrapper be 512 on Tacotron2 branch #168

Closed begeekmyfriend closed 6 years ago

begeekmyfriend commented 6 years ago

In tacotron.py we can see that attention_depth is set to 128, and the attention cell is then wrapped in ConcatOutputAndAttentionWrapper, where the attention RNN output and the context vector are concatenated. But I would guess the concatenated dim should be 256 instead of 512.

syang1993 commented 6 years ago

In his implementation, ConcatOutputAndAttentionWrapper concatenates the cell_output and the context vector of AttentionWrapper(cell, attention_mechanism, ...). In TensorFlow, the dim of cell_output is the number of units of the cell, and the dim of the context vector is the dim of the memory passed to the attention_mechanism. For example, in his implementation:

https://github.com/keithito/tacotron/blob/master/models/tacotron.py#L51-L55

the dim of cell_output is hp.attention_depth and the dim of the context vector is hp.encoder_depth, so the concatenated dim is hp.attention_depth + hp.encoder_depth, not hp.attention_depth * 2.
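
For reference, here is a minimal sketch of what a wrapper like ConcatOutputAndAttentionWrapper does, assuming the wrapped cell is an AttentionWrapper built with output_attention=False. This is an illustration of the idea, not necessarily the exact code in models/rnn_wrappers.py:

import tensorflow as tf
from tensorflow.contrib.rnn import RNNCell

class ConcatOutputAndAttentionWrapper(RNNCell):
  # Concatenates the wrapped cell's output with the attention (context) vector
  # kept in its state. Assumes the wrapped cell is an AttentionWrapper built
  # with output_attention=False.
  def __init__(self, cell):
    super(ConcatOutputAndAttentionWrapper, self).__init__()
    self._cell = cell

  @property
  def state_size(self):
    return self._cell.state_size

  @property
  def output_size(self):
    # cell_output dim + context vector dim
    return self._cell.output_size + self._cell.state_size.attention

  def call(self, inputs, state):
    output, res_state = self._cell(inputs, state)
    # Concatenate the RNN output with the attention context from the state.
    return tf.concat([output, res_state.attention], axis=-1), res_state

  def zero_state(self, batch_size, dtype):
    return self._cell.zero_state(batch_size, dtype)

So the output dim is the cell_output dim plus the memory dim, independent of the attention mechanism's num_units.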

begeekmyfriend commented 6 years ago

@syang1993 What you are looking at is the master branch; what I am referring to is the tacotron2-work-in-progress branch, where attention_depth is set to 128.

syang1993 commented 6 years ago

@begeekmyfriend It's the same in the tacotron2-work-in-progress branch; I just used master as an example. For that branch, the dim is hp.attention_depth + hp.encoder_depth * 2 = 640 (since the encoder uses a BLSTM, the dim of the context vector is 2 * hp.encoder_depth).

You can see the implementation of AttentionWrapper for details:

https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py

The attention_depth in this repo only affects the query and key depth of the attention mechanism.

To make it easier to understand, suppose:

encoder_output = encoder_net(input_text)          # [batch_size, length, dim_a]
attention_cell = AttentionWrapper(
    DecoderPrenet(GRUCell(dim_b)),                # cell_output dim: dim_b
    attention_mechanism(dim_c, encoder_output))   # context vector dim: dim_a
concat_cell = ConcatOutputAndAttentionWrapper(attention_cell)  # output dim: dim_b + dim_a

Then the output dim of concat_cell is dim_a + dim_b. dim_c (which refers to attention_depth) does not affect the output dim.
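
If you want to check this yourself, here is a minimal sketch using the TF 1.x contrib API with hypothetical shapes matching the numbers above (attention_depth = 128, bidirectional encoder output depth = 2 * 256 = 512):

import tensorflow as tf
from tensorflow.contrib.rnn import GRUCell
from tensorflow.contrib.seq2seq import AttentionWrapper, BahdanauAttention

# Hypothetical memory: the encoder is a BLSTM, so its output depth is 2 * encoder_depth = 512.
memory = tf.placeholder(tf.float32, [None, None, 512])

attention_cell = AttentionWrapper(
    GRUCell(128),                    # cell_output dim = attention_depth = 128
    BahdanauAttention(128, memory),  # num_units only sets the query/key depth
    output_attention=False)

print(attention_cell.output_size)           # 128 (the GRU cell output)
print(attention_cell.state_size.attention)  # 512 (the context vector / memory depth)
# A concat wrapper around attention_cell would therefore output 128 + 512 = 640.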

begeekmyfriend commented 6 years ago

@syang1993 I just printed the output_size of each cell. You are right.