tensorflow / nmt

TensorFlow Neural Machine Translation Tutorial
Apache License 2.0

Where are the attention vectors built? #281

Status: Open. lvcasgm opened this issue 6 years ago.

lvcasgm commented 6 years ago

I'm trying to modify this project's code to simplify how its attention mechanism works. I don't want to remove the attention mechanism entirely (that would be easy: just set --attention=""). What I want is to find where the attention vectors are produced, so that I can modify them before the network uses them to produce an output.

I'm able to get the attention matrix from _create_attention_images_summary in attention_model.py, but (correct me if I'm wrong) I believe this is "the final result". It has already been used to produce an output, so modifying it there would not make the network use my modified attention vectors.

I am using a unidirectional setup with scaled_luong attention (I could switch to a different attention type if that makes modifying the attention vectors easier).

Which file/function should I look at in order to modify the attention vectors before the decoder uses them?

Thank you :)

martingrm commented 6 years ago

I think you are looking for the function _compute_attention in the contrib/seq2seq/attention_wrapper.py file. It calls the attention mechanism, which computes the alignments (one weight per source "word"). In your case, since you are using scaled_luong attention, it calls the __call__ method of the LuongAttention class, which calls the _luong_score function to compute the alignment scores and then applies the probability function (softmax by default) to turn those scores into alignments.
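To make that call chain concrete, here is a plain-Python sketch (no TensorFlow) of what one decoder step of scaled Luong attention computes, based on my reading of the TF 1.x source. The score is the dot product of the decoder output (the query) with each encoder key, multiplied by a learned scalar `g` when scaling is on. Softmax turns the scores into alignments, and the context vector is the alignment-weighted sum of the encoder values. The function names `luong_scores`, `softmax`, `attention_step` and the `modify` hook are hypothetical stand-ins, not TF APIs. The point is that the alignments exist as a separate intermediate between the softmax and the context computation, and that is where a modification would go.

```python
import math
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]


def luong_scores(query: Sequence[float], keys: Sequence[Sequence[float]], g: float = 1.0) -> Vector:
    """Scaled Luong (multiplicative) score: g * <query, key_j> for each source position j."""
    return [g * sum(q * k for q, k in zip(query, key)) for key in keys]


def softmax(scores: Sequence[float]) -> Vector:
    """Numerically stable softmax, standing in for the default probability function."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_step(
    query: Sequence[float],
    keys: Sequence[Sequence[float]],
    values: Sequence[Sequence[float]],
    g: float = 1.0,
    modify: Optional[Callable[[Vector], Vector]] = None,
) -> Tuple[Vector, Vector]:
    """One decoder step: scores -> alignments -> (optional edit) -> context vector."""
    alignments = softmax(luong_scores(query, keys, g))
    if modify is not None:
        # Hypothetical hook: the alignments can still be changed here,
        # before the context vector is built from them.
        alignments = modify(alignments)
    dim = len(values[0])
    # Context vector: alignment-weighted sum of the encoder values.
    context = [sum(a * v[d] for a, v in zip(alignments, values)) for d in range(dim)]
    return alignments, context
```

For example, passing `modify=lambda a: [1.0, 0.0]` forces all attention onto the first source position, and the returned context vector becomes exactly the first row of `values`.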

To summarise: if you want to modify the alignment vectors, look at the contrib/seq2seq/attention_wrapper.py file, more specifically near the end of the file, in the call method of AttentionWrapper, where the _compute_attention function is called. You can get the position of the current word from the state.time variable.
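Before touching the TF code, it can help to prototype the edit itself. Below is a plain-Python sketch of one possible, purely hypothetical, modification. It uses the decoder step index (the value `state.time` holds) to bias the alignments toward the matching source position and then renormalizes them so they still sum to 1. The function name `windowed_alignments` and the Gaussian window are my own illustration, not part of nmt or TensorFlow. Inside attention_wrapper.py, `state.time` is a scalar tensor and the alignments have shape `[batch_size, max_time]`. The same idea would therefore have to be written with TF ops (for example `tf.range` and `tf.exp`) instead of Python loops.

```python
import math
from typing import List, Sequence


def windowed_alignments(alignments: Sequence[float], time: int, width: float = 2.0) -> List[float]:
    """Hypothetical edit: down-weight source positions far from `time`, then renormalize."""
    # Multiply each alignment by a Gaussian bump centered on source position `time`.
    weighted = [
        a * math.exp(-((j - time) ** 2) / (2 * width ** 2))
        for j, a in enumerate(alignments)
    ]
    total = sum(weighted)
    if total == 0.0:
        # Degenerate case (e.g. all-zero input): return the distribution unchanged.
        return list(alignments)
    # Renormalize so the result is still a probability distribution.
    return [w / total for w in weighted]
```

Renormalizing matters because the wrapper builds the context vector as a weighted sum of the encoder outputs. Weights that no longer sum to 1 would silently rescale that vector.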

If I were you, I would create a new custom AttentionMechanism, leave it almost empty, and put your new method for computing the alignments in its __call__ method.
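Here is a minimal sketch of that design, written in plain Python rather than TensorFlow so it runs standalone. It mimics the interface that _compute_attention relies on, as far as I can tell from the TF 1.x source. The mechanism is called as `mechanism(query, state=...)` and returns `(alignments, next_state)`, and it exposes the encoder memory as `values`. The class `FixedRuleAttention`, its `rule` argument, `diagonal_rule` and the `compute_attention` helper are all hypothetical. In the real library you would subclass `_BaseAttentionMechanism` (or `LuongAttention`) in attention_wrapper.py and return TF tensors. Check the exact constructor and `__call__` signature in your installed TF version, since they changed across 1.x releases.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]


class FixedRuleAttention:
    """Hypothetical custom mechanism: alignments come from a user-supplied rule, not learned scores."""

    def __init__(self, memory: Matrix, rule: Callable[[Sequence[float], int, int], Vector]):
        self.values = memory  # encoder outputs the context vector is built from
        self._rule = rule     # maps (query, step, source_length) -> alignments

    def __call__(self, query: Sequence[float], state: int) -> Tuple[Vector, int]:
        # Simplification: `state` is a step counter here. In TF's LuongAttention
        # the attention state is the previous alignments tensor instead.
        alignments = self._rule(query, state, len(self.values))
        return alignments, state + 1


def diagonal_rule(query: Sequence[float], step: int, n: int) -> Vector:
    """Hypothetical rule: hard monotonic attention on source position `step` (clipped to n-1)."""
    pos = min(step, n - 1)
    return [1.0 if j == pos else 0.0 for j in range(n)]


def compute_attention(mechanism, query: Sequence[float], state: int) -> Tuple[Vector, Vector, int]:
    """Mirror of the wrapper step: call the mechanism, then weight its values by the alignments."""
    alignments, next_state = mechanism(query, state=state)
    dim = len(mechanism.values[0])
    context = [sum(a * row[d] for a, row in zip(alignments, mechanism.values)) for d in range(dim)]
    return alignments, context, next_state
```

With `FixedRuleAttention([[1.0, 2.0], [3.0, 4.0]], diagonal_rule)`, the first step yields the context `[1.0, 2.0]` and the second `[3.0, 4.0]`. The decoder sees whatever alignments the rule produced, which is exactly the control over the attention vectors the question asks for.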