localminimum / R-net

A TensorFlow implementation of R-net: Machine Reading Comprehension with Self-Matching Networks
MIT License

Would you mind explaining `cross_entropy_with_sequence_mask` to me? #16

Closed: Levstyle closed this issue 7 years ago.

Levstyle commented 7 years ago
```python
import tensorflow as tf

def cross_entropy_with_sequence_mask(output, target):
    # output: softmax probabilities, target: one-hot labels, both B x 2 x N
    cross_entropy = target * tf.log(output + 1e-8) # B x 2 x N
    cross_entropy = -tf.reduce_sum(cross_entropy, 2) # B x 2
    mask = tf.sign(tf.reduce_max(tf.abs(target), 2)) # B x 2
    cross_entropy *= mask # B x 2
    cross_entropy = tf.reduce_sum(cross_entropy, 1) # B
    cross_entropy /= tf.reduce_sum(mask, 1) # B
    return tf.reduce_mean(cross_entropy) # 1
```

When I read the code above, I find it hard to understand what the mask does.

As I understand it, `output` and `target` both have shape (B, 2, N): there are B samples, and for each sample we have to predict a start position and an end position, each chosen from N candidate positions.

But the actual lengths of most samples are shorter than N, so the extra (padding) positions have to be masked out. That means the mask should act along axis 2, not axis 1.
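To see concretely why this mask has no effect, here is a minimal pure-Python port of the function above, using nested lists instead of tensors. The helper name `target_mask` and the toy numbers are illustrative assumptions, not from the repository. With one-hot targets every start and end row contains a 1, so the B x 2 mask is all ones and never looks at the padded positions along axis 2.

```python
import math

def target_mask(target):
    """Mirrors tf.sign(tf.reduce_max(tf.abs(target), 2)); returns a B x 2 list."""
    return [[1.0 if max(abs(t) for t in row) > 0 else 0.0 for row in tgt_b]
            for tgt_b in target]

def cross_entropy_with_sequence_mask(output, target):
    """Pure-Python port of the TF function; output/target are B x 2 x N lists."""
    per_example = []
    for out_b, tgt_b, mask_b in zip(output, target, target_mask(target)):
        # Cross entropy of the start and end distributions (sum over axis 2)
        ce = [-sum(t * math.log(o + 1e-8) for o, t in zip(o_row, t_row))
              for o_row, t_row in zip(out_b, tgt_b)]
        # Apply the B x 2 mask, sum over start/end, divide by the mask count
        per_example.append(sum(c * m for c, m in zip(ce, mask_b)) / sum(mask_b))
    return sum(per_example) / len(per_example)  # mean over the batch

# One example with N = 4, where position 3 is padding (true length 3)
output = [[[0.5, 0.3, 0.2, 0.0], [0.1, 0.6, 0.3, 0.0]]]
target = [[[1, 0, 0, 0], [0, 1, 0, 0]]]
mask = target_mask(target)
loss = cross_entropy_with_sequence_mask(output, target)
```

Here `mask` is `[[1.0, 1.0]]`, and `loss` is just the average of `-log(0.5)` and `-log(0.6)`. The mask could only zero out a whole start or end row, and only if that row's target were all zeros.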

ghost commented 7 years ago

Hi @Levstyle, thanks for raising this! You're right about the dimensions. As it stands, the cost function behaves exactly as if there were no mask. I will fix this and report the resulting improvement.

huangpeng1126 commented 7 years ago

I don't understand the reason for using the mask. BTW, I think tf.losses.softmax_cross_entropy could replace the function above, like this:

```python
cross_entropy = tf.losses.softmax_cross_entropy(tf.reshape(target, (-1, N, ...)
```

But I'm not sure how the mask works. I would appreciate an explanation.
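The reshape idea holds at least for the averaging: when every example has exactly two rows (start and end) and the mask is all ones, averaging each example's two losses and then the batch gives the same number as flattening to (B*2) x N rows and averaging once. The pure-Python check below is a sketch with made-up numbers and hypothetical helper names. It only checks the averaging; it says nothing about whether the inputs are logits or probabilities, which matters because tf.losses.softmax_cross_entropy expects unnormalized logits.

```python
import math

def row_ce(probs_row, onehot_row):
    # Cross entropy of one N-way distribution against a one-hot label
    return -sum(t * math.log(p + 1e-8) for p, t in zip(probs_row, onehot_row))

def per_example_mean(output, target):
    # The original function with an all-ones mask:
    # average the start/end losses per example, then average over the batch
    per_ex = [sum(row_ce(o, t) for o, t in zip(out_b, tgt_b)) / len(out_b)
              for out_b, tgt_b in zip(output, target)]
    return sum(per_ex) / len(per_ex)

def flattened_mean(output, target):
    # Reshape B x 2 x N to (B*2) x N and average over every row at once
    rows = [(o, t) for out_b, tgt_b in zip(output, target)
            for o, t in zip(out_b, tgt_b)]
    return sum(row_ce(o, t) for o, t in rows) / len(rows)

# B = 2, N = 3 (illustrative probabilities and one-hot labels)
output = [[[0.7, 0.2, 0.1], [0.2, 0.5, 0.3]],
          [[0.1, 0.8, 0.1], [0.3, 0.3, 0.4]]]
target = [[[1, 0, 0], [0, 1, 0]],
          [[0, 1, 0], [0, 0, 1]]]
```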

ghost commented 7 years ago

@huangpeng1126 I had actually noticed this mistake but hadn't had time to commit a fix. Because we already multiply the log of the output by the label, which is essentially just a one-hot vector, we don't need masking here at all. I will be removing the mask from this function, since we don't want to mask the cross entropy but rather the logits output by the pointer network before softmax (see here). For practical purposes, though, the current cross_entropy_with_sequence_mask works perfectly fine: it sums across the passage_length dimension, multiplies by a mask that makes no difference at all, sums across the time dimension of the pointer network (start and end), and divides by 2, which is identical to taking the average of the two cross entropies. Please feel free to correct me if I'm wrong. Also, the reason I can't use tf.losses.softmax_cross_entropy is that I'm already applying softmax in the pointer network before outputting the logits. (Check here and here.)
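A minimal sketch of masking the pointer-network logits before the softmax, in plain Python for a single row of N logits. The function name `masked_softmax` and its `length` argument are hypothetical, not taken from the repository. Padding positions are set to -inf so they receive exactly zero probability; in TensorFlow the same effect is commonly obtained by adding a large negative constant to the padded logits before `tf.nn.softmax`.

```python
import math

def masked_softmax(logits, length):
    """Softmax over one row of logits, ignoring padding.

    logits: list of N floats; length: number of real positions (>= 1).
    Positions with index >= length get probability exactly 0.
    """
    # Replace padded logits with -inf so that exp() maps them to 0.0
    masked = [x if i < length else float("-inf") for i, x in enumerate(logits)]
    m = max(masked[:length])  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]

# The padded position has the largest logit, but it gets no probability mass
probs = masked_softmax([2.0, 1.0, 0.5, 3.0], length=3)
```

Without the mask, the padded position (logit 3.0) would take most of the probability. Once the distribution is masked this way, the cross entropy needs no extra mask, because the one-hot label already picks out a single real position.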

ghost commented 7 years ago

Fixed as of commit 9997f7506d8c9bfe27ea49a817591e7de30b31a8. Thanks!

huangpeng1126 commented 7 years ago

@minsangkim142

> but rather the logits output by the pointer network before softmax

That's great; it seems like a better idea.

> Also, the reason I can't use tf.losses.softmax_cross_entropy is that I'm already applying softmax in the pointer network before outputting the logits.

Thanks for explaining that. You're right: softmax_cross_entropy gives the wrong result when its inputs have already been passed through a softmax.
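A small pure-Python illustration of why this matters (the helper names are hypothetical). `tf.losses.softmax_cross_entropy` expects unnormalized logits and applies the softmax itself, so passing probabilities applies softmax twice: the distribution gets flattened and the loss is inflated.

```python
import math

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def ce_from_probs(probs, label):
    # Correct loss when the network already outputs probabilities
    return -math.log(probs[label])

def ce_from_logits(logits, label):
    # What a softmax-cross-entropy loss computes: it applies softmax itself
    return -math.log(softmax(logits)[label])

logits = [4.0, 1.0, 0.0]
probs = softmax(logits)             # what the pointer network outputs
correct = ce_from_probs(probs, 0)
double = ce_from_logits(probs, 0)   # probabilities passed where logits are expected
```

Here `correct` is about 0.066 while `double` is about 0.59, because `softmax(probs)` squashes values already in [0, 1] toward a uniform distribution.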

BTW, your code is very clean, and I learned lots of coding tricks from it.