Kyubyong / transformer

A TensorFlow Implementation of the Transformer: Attention Is All You Need

About the query_mask #182

Open · bjuthjliu opened this issue 2 years ago

bjuthjliu commented 2 years ago

Source Code:

    padding_num = -2 ** 32 + 1
    if type in ("k", "key", "keys"):
        key_masks = tf.to_float(key_masks)
        key_masks = tf.tile(key_masks, [tf.shape(inputs)[0] // tf.shape(key_masks)[0], 1]) # (h*N, seqlen)
        key_masks = tf.expand_dims(key_masks, 1)  # (h*N, 1, seqlen)
        outputs = inputs + key_masks * padding_num
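
For context, the additive mask above relies on broadcasting: the (h*N, 1, seqlen) mask is added to the (h*N, T_q, seqlen) score tensor, so every query row receives padding_num on the padded key columns. A minimal runnable sketch with made-up toy shapes and values (assuming, as in the repo's calling code, that key_masks is 1 at padded positions):

    import tensorflow as tf  # TF 1.x, as used by this repo

    padding_num = -2 ** 32 + 1
    inputs = tf.constant([[[0.1, 0.2, 0.3],
                           [0.4, 0.5, 0.6]]])   # (h*N, T_q, T_k) = (1, 2, 3)
    key_masks = tf.constant([[0., 0., 1.]])     # (N, T_k); 1 marks padding
    key_masks = tf.expand_dims(key_masks, 1)    # (N, 1, T_k)
    outputs = inputs + key_masks * padding_num  # broadcasts over the T_q axis

    with tf.Session() as sess:
        print(sess.run(outputs))  # last key column is ~ -2**32 for both query rows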

I think the mask should instead be tiled out to the full (h*N, T_q, T_k) shape and applied with tf.where:

    padding_num = -2 ** 32 + 1
    if type in ("k", "key", "keys"):
        key_masks = tf.to_float(key_masks) # (N, T_k)
        key_masks = tf.tile(key_masks, [tf.shape(inputs)[0] // tf.shape(key_masks)[0], 1]) # (h*N, T_k)
        key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(inputs)[1], 1]) # (h*N, T_q, T_k); tile along T_q, taken from inputs
        paddings = tf.ones_like(key_masks) * padding_num
        outputs = tf.where(tf.equal(key_masks, 1), paddings, inputs) # key_masks is 1 at padded key positions
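
One detail worth noting: in TF 1.x, tf.where requires condition, x, and y to have identical shapes, so the (N, T_k) mask really does have to be expanded to (h*N, T_q, T_k) before it can gate the score tensor. A self-contained sketch of this variant on the same toy shapes (values are illustrative only):

    import tensorflow as tf  # TF 1.x

    padding_num = -2 ** 32 + 1
    inputs = tf.constant([[[0.1, 0.2, 0.3],
                           [0.4, 0.5, 0.6]]])        # (h*N, T_q, T_k) = (1, 2, 3)
    key_masks = tf.constant([[0., 0., 1.]])          # (N, T_k); 1 marks padding
    key_masks = tf.tile(tf.expand_dims(key_masks, 1),
                        [1, tf.shape(inputs)[1], 1]) # (h*N, T_q, T_k)
    paddings = tf.ones_like(inputs) * padding_num
    outputs = tf.where(tf.equal(key_masks, 1), paddings, inputs)

    with tf.Session() as sess:
        print(sess.run(outputs))  # padded key columns are replaced by padding_num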