NELSONZHAO / zhihu

This repo contains the source code for my personal column (https://zhuanlan.zhihu.com/zhaoyeyu), implemented in Python 3.6. It includes Natural Language Processing and Computer Vision projects, such as hands-on code for text generation, machine translation, and deep convolutional GANs.

About the visualization part of mt_attention_birnn #40

Open headwheel opened 4 years ago

headwheel commented 4 years ago

Hello, while trying out your code I found that the visualization function below does not produce the result shown in your article. What could be the cause?

```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from keras import backend as K

def plot_attention(sentence, Tx=20, Ty=25):
    """
    Visualize the attention layer.

    @param sentence: sentence to translate, str
    @param Tx: input sentence length
    @param Ty: output sentence length
    """
    # encode the source sentence as integer ids
    X = np.array(text_to_int(sentence, source_vocab_to_int))
    # build a backend function returning the attention layer's output at each decoding step
    f = K.function(model.inputs, [model.layers[9].get_output_at(t) for t in range(Ty)])

    # initial decoder states and initial output token
    s0 = np.zeros((1, n_s))
    c0 = np.zeros((1, n_s))
    out0 = np.zeros((1, len(target_vocab_to_int)))

    r = f([X.reshape(-1, 20), s0, c0, out0])

    # collect the attention weight for each (output step, input position) pair
    attention_map = np.zeros((Ty, Tx))
    for t in range(Ty):
        for t_prime in range(Tx):
            attention_map[t][t_prime] = r[t][0, t_prime, 0]

    Y = make_prediction(sentence)

    source_list = sentence.split()
    target_list = Y.split()

    # heatmap: x axis = source tokens, y axis = predicted target tokens
    fig, ax = plt.subplots(figsize=(20, 15))
    sns.heatmap(attention_map, xticklabels=source_list, yticklabels=target_list, cmap="YlGnBu")
    ax.set_xticklabels(ax.get_xticklabels(), fontsize=15, rotation=90)
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=15)
```

I printed the attention_map array and found that every value is 0.05. (screenshot attached)
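As a sanity check on that observation (not part of the original thread): a softmax over Tx=20 identical attention energies assigns exactly 1/20 = 0.05 to every input position, so an attention_map that is uniformly 0.05 means the attention layer is producing flat, non-discriminative weights for this input, e.g. from untrained or unloaded weights:

```python
import numpy as np

def softmax(x):
    # numerically stable softmax over a 1-D array
    e = np.exp(x - x.max())
    return e / e.sum()

Tx = 20
# attention energies that are identical across all input positions
energies = np.zeros(Tx)
weights = softmax(energies)
print(weights)  # every weight equals 1/20 = 0.05
```

This suggests checking that the trained model weights are actually loaded before calling the visualization function.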