lsdefine / attention-is-all-you-need-keras

A Keras+TensorFlow Implementation of the Transformer: Attention Is All You Need

MultiHeadAttention #2

Closed AMSakhnov closed 6 years ago

AMSakhnov commented 6 years ago

Hi!

Having n_head == 1 is unusual, but it does not work in the MultiHeadAttention class (mode=1). To fix it, it is enough to change

head = Concatenate()(heads)
attn = Concatenate()(attns)

to

if n_head == 1:
    head = heads[0]
    attn = attns[0]
else:
    head = Concatenate()(heads)
    attn = Concatenate()(attns)

because calling Concatenate with a single input raises:

A `Concatenate` layer should be called on a list of at least 2 inputs
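
For context, here is a minimal, self-contained sketch of the mode=1 pattern with the guard in place. It is not the repository's actual class: it assumes TensorFlow's bundled Keras, swaps in the built-in Attention layer for the repo's scaled dot-product attention, and all sizes are hypothetical.

from tensorflow.keras.layers import Input, Dense, Attention, Concatenate
from tensorflow.keras.models import Model

n_head, d_model, d_k = 1, 64, 64  # hypothetical sizes; n_head == 1 triggers the bug

q = Input(shape=(None, d_model))
k = Input(shape=(None, d_model))
v = Input(shape=(None, d_model))

heads = []
for _ in range(n_head):
    qi = Dense(d_k, use_bias=False)(q)  # per-head query projection
    ki = Dense(d_k, use_bias=False)(k)  # per-head key projection
    vi = Dense(d_k, use_bias=False)(v)  # per-head value projection
    # built-in dot-product attention stands in for the repo's per-head attention
    heads.append(Attention()([qi, vi, ki]))

# the guard from this issue: Concatenate requires at least two inputs
head = heads[0] if n_head == 1 else Concatenate()(heads)

out = Dense(d_model)(head)
model = Model([q, k, v], out)
model.summary()

With n_head == 1 the single head passes through unchanged; with n_head >= 2 the heads are concatenated along the feature axis as before.
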
lsdefine commented 6 years ago

Fixed. Thank you.