lsdefine / attention-is-all-you-need-keras

A Keras+TensorFlow Implementation of the Transformer: Attention Is All You Need

dimension in GetSubMask #23

Open ichenjia opened 5 years ago

ichenjia commented 5 years ago

```python
def GetSubMask(s):
    len_s = tf.shape(s)[1]
    bs = tf.shape(s)[:1]
    mask = K.cumsum(tf.eye(len_s, batch_shape=bs), 1)
    return mask
```

If the input has shape (5, 4, 3), wouldn't tf.eye here (together with the cumsum) create a lower-triangular tensor of shape (5, 4, 4) instead of (5, 4, 3), because bs = tf.shape(s)[:1] keeps only the batch dimension?
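For reference, a minimal standalone shape check of the snippet above, assuming eager TF 2.x and using tf.cumsum in place of K.cumsum (the dummy tensor `s` and its (5, 4, 3) shape are just an illustration of the case in the question):

```python
import tensorflow as tf

# Dummy input: batch=5, sequence length=4, feature dim=3
s = tf.zeros((5, 4, 3))

len_s = tf.shape(s)[1]   # 4   -- the sequence length
bs = tf.shape(s)[:1]     # [5] -- only the batch dimension, because of [:1]

# One 4x4 identity matrix per batch element -> shape (5, 4, 4)
eye = tf.eye(len_s, batch_shape=bs)

# Cumulative sum down axis 1 turns each identity into a
# lower-triangular matrix -> still shape (5, 4, 4)
mask = tf.cumsum(eye, axis=1)

print(eye.shape, mask.shape)  # (5, 4, 4) (5, 4, 4)
```

So with a (5, 4, 3) input the resulting mask is indeed (5, 4, 4), i.e. its last two dimensions are both the sequence length rather than matching the input's feature dimension.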