CharlesShang / TFFRCNN

Faster RCNN built on TensorFlow
MIT License

OHEM bug? #60

Open christopher5106 opened 7 years ago

christopher5106 commented 7 years ago

Hi,

In the loss, it seems that `tf.reduce_sum(tf.cast(bg_, tf.float32))` is not the correct normalizer for the negative term. Shouldn't it be `top_k`, the number of negatives actually kept?

    if ohem:
        # k = tf.minimum(tf.shape(rpn_cross_entropy_n)[0] / 2, 300)
        # # k = tf.shape(rpn_loss_n)[0] / 2
        # rpn_loss_n, top_k_indices = tf.nn.top_k(rpn_cross_entropy_n, k=k, sorted=False)
        # rpn_cross_entropy_n = tf.gather(rpn_cross_entropy_n, top_k_indices)
        # rpn_loss_box_n = tf.gather(rpn_loss_box_n, top_k_indices)

        # strategy: keeps all the positive samples
        fg_ = tf.equal(rpn_label, 1)
        bg_ = tf.equal(rpn_label, 0)
        pos_inds = tf.where(fg_)
        neg_inds = tf.where(bg_)
        rpn_cross_entropy_n_pos = tf.reshape(tf.gather(rpn_cross_entropy_n, pos_inds), [-1])
        rpn_cross_entropy_n_neg = tf.reshape(tf.gather(rpn_cross_entropy_n, neg_inds), [-1])
        top_k = tf.cast(tf.minimum(tf.shape(rpn_cross_entropy_n_neg)[0], 300), tf.int32)
        rpn_cross_entropy_n_neg, _ = tf.nn.top_k(rpn_cross_entropy_n_neg, k=top_k)
        rpn_cross_entropy = tf.reduce_sum(rpn_cross_entropy_n_neg) / (tf.reduce_sum(tf.cast(bg_, tf.float32)) + 1.0) \
                            + tf.reduce_sum(rpn_cross_entropy_n_pos) / (tf.reduce_sum(tf.cast(fg_, tf.float32)) + 1.0)
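To make the concern concrete, here is a minimal pure-Python sketch of the loss quoted above (no TensorFlow, so it can be run anywhere). The function name `ohem_rpn_cross_entropy` and the flag `normalize_neg_by_top_k` are hypothetical and exist only for illustration; `max_neg` mirrors the hard-coded `300`. It assumes, as the question suggests, that the intended normalizer for the negative term is the number of hard negatives kept (`top_k`) rather than the count of all background anchors. The `+ 1.0` guard from the original is kept in both cases.

```python
def ohem_rpn_cross_entropy(losses, labels, max_neg=300, normalize_neg_by_top_k=True):
    """Hypothetical sketch of the RPN OHEM classification loss.

    losses: per-anchor cross-entropy values (like rpn_cross_entropy_n)
    labels: per-anchor labels, 1 = foreground, 0 = background, -1 = ignored
    """
    # Split the per-anchor losses by label, like tf.where + tf.gather.
    pos = [l for l, y in zip(losses, labels) if y == 1]
    neg = [l for l, y in zip(losses, labels) if y == 0]

    # Keep only the hardest negatives, like tf.nn.top_k with k = min(#neg, max_neg).
    k = min(len(neg), max_neg)
    hard_neg = sorted(neg, reverse=True)[:k]

    # The quoted code divides by the number of ALL background anchors;
    # the alternative divides by k, the number of negatives actually summed.
    neg_count = k if normalize_neg_by_top_k else len(neg)
    neg_term = sum(hard_neg) / (neg_count + 1.0)
    pos_term = sum(pos) / (len(pos) + 1.0)
    return neg_term + pos_term


losses = [2.0, 0.5, 1.0, 0.1, 0.3, 0.2]
labels = [1, 0, 0, 0, 0, -1]
print(ohem_rpn_cross_entropy(losses, labels, max_neg=2))  # -> 1.5
print(ohem_rpn_cross_entropy(losses, labels, max_neg=2,
                             normalize_neg_by_top_k=False))  # -> 1.3
```

With only 2 of 4 negatives kept, dividing by the full background count shrinks the negative term from 0.5 to 0.3. In a real RPN, where typically far more than 300 anchors are background, the same division would scale the hard-negative loss down much more, which is the effect the question is pointing at.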