ichenjia opened 5 years ago
    def GetSubMask(s):
        len_s = tf.shape(s)[1]
        bs = tf.shape(s)[:1]
        mask = K.cumsum(tf.eye(len_s, batch_shape=bs), 1)
        return mask
If the input has shape (5, 4, 3), wouldn't `tf.eye` here create a lower-triangular tensor of shape (5, 4, 4) instead of (5, 4, 3), because of `[:1]`?
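To check the shapes, here is a minimal sketch that re-expresses `GetSubMask` in NumPy. It assumes that broadcasting `np.eye(len_s)` over the batch axis behaves like `tf.eye(len_s, batch_shape=bs)` and that `np.cumsum(..., axis=1)` behaves like `K.cumsum(..., 1)`; the name `get_sub_mask_np` is hypothetical.

```python
import numpy as np

def get_sub_mask_np(s):
    """Hypothetical NumPy re-expression of GetSubMask (assumed equivalent to tf.eye + K.cumsum)."""
    len_s = s.shape[1]   # sequence length: the second axis of s
    bs = s.shape[:1]     # batch shape only, e.g. (5,); the feature axis is dropped
    # Batched identity of shape bs + (len_s, len_s), like tf.eye(len_s, batch_shape=bs)
    eye = np.broadcast_to(np.eye(len_s), bs + (len_s, len_s))
    # Cumulative sum down the rows (axis 1): mask[b, i, j] == 1 exactly when j <= i
    return np.cumsum(eye, axis=1)

s = np.zeros((5, 4, 3))
mask = get_sub_mask_np(s)
print(mask.shape)  # (5, 4, 4)
print(mask[0])     # 4x4 lower-triangular matrix of ones (diagonal included)
```

Under these assumptions an input of shape (5, 4, 3) yields a mask of shape (5, 4, 4): `[:1]` keeps only the batch size, and both trailing axes are `len_s`, so the feature dimension 3 never enters the mask. Whether that is intended depends on how the mask is consumed; a self-attention mask indexed by (query position, key position) would usually be a square `len_s × len_s` matrix per example.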