LongmaoTeamTf / deep_recommenders

Deep Recommenders
Apache License 2.0
324 stars 108 forks source link

这段挖掘 hard 负样本的方法不太懂，还请大佬赐教 #17

Open · SapereAudo opened 1 year ago

SapereAudo commented 1 year ago

class HardNegativeMining(tf.keras.layers.Layer):
    """Keep each row's positive(s) plus its hardest negatives; drop the rest.

    "Hard" negatives are the negative candidates the model currently scores
    highest, i.e. the ones it finds most easily confused with the positive.
    For every row, this layer keeps only the ``num_hard_negatives + 1``
    columns with the largest *adjusted* score. A downstream loss then sees
    only the positive and the hardest negatives instead of every candidate.

    How the selection works (see ``call``):

    1. ``labels * MAX_FLOAT`` adds a huge constant at every positive
       position (where ``labels`` is non-zero). This makes a positive outrank
       any negative, so it always gets one of the top-k slots.
    2. ``tf.nn.top_k`` takes the ``num_hard_negatives + 1`` largest adjusted
       entries per row: the positive, plus the negatives with the highest
       raw logits (the hard negatives).
    3. The chosen column indices are gathered from the *original* logits and
       from the labels, so the two outputs stay aligned column-for-column.

    Assumptions (not enforced here; confirm against callers):
      * ``logits`` and ``labels`` have the same shape, presumably
        (batch, num_candidates), and the same dtype. TF does not promote
        dtypes in ``logits + labels * MAX_FLOAT``.
      * ``labels`` is 0/1 with one positive per row. With several positives
        per row, each one consumes a top-k slot and fewer negatives survive.
        If a row has more positives than slots, some positives are dropped.
      * ``MAX_FLOAT`` is a large positive constant defined elsewhere in this
        module, big enough to dominate any real logit.
    """

    def __init__(self, num_hard_negatives: int, **kwargs):
        """Create the layer.

        Args:
            num_hard_negatives: How many of the highest-scoring negatives to
                keep per row, in addition to the positive. Must be >= 0.
            **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.

        Raises:
            ValueError: If ``num_hard_negatives`` is negative. A negative
                value would make ``top_k`` request zero or fewer columns,
                which drops the positive or fails at call time.
        """
        if num_hard_negatives < 0:
            raise ValueError(
                f"num_hard_negatives must be non-negative, got {num_hard_negatives}"
            )
        super().__init__(**kwargs)

        self._num_hard_negatives = num_hard_negatives

    def call(
        self, logits: tf.Tensor, labels: tf.Tensor
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """Select the positive and hardest-negative columns of each row.

        Args:
            logits: Model scores; assumed (batch, num_candidates). TODO confirm.
            labels: Targets aligned with ``logits``; assumed 0/1 with the
                positive marked by 1. TODO confirm.

        Returns:
            ``(logits, labels)`` restricted to the selected columns. Each has
            (presumed) shape (batch, min(num_hard_negatives + 1,
            num_candidates)). Because ``top_k`` runs with ``sorted=False``,
            the column order is unspecified, so the positive is not
            necessarily in column 0. Use the returned labels to locate it.
        """
        # One slot for the positive plus num_hard_negatives slots for
        # negatives. Capped at the number of columns so top_k never requests
        # more elements than exist.
        num_sampled = tf.minimum(self._num_hard_negatives + 1, tf.shape(logits)[1])

        # Boost positives so they always win a top-k slot. The remaining
        # slots go to the negatives with the largest logits, i.e. the hard
        # negatives. Only the indices are used; the boosted values are
        # discarded so the huge offset never reaches the loss.
        _, indices = tf.nn.top_k(logits + labels * MAX_FLOAT, k=num_sampled, sorted=False)

        # Gather the chosen columns from the unboosted logits and from the
        # labels with the same indices, keeping the pair aligned.
        # _gather_elements_along_row is defined elsewhere in this module;
        # presumably out[i, j] = x[i, indices[i, j]]. TODO confirm.
        logits = _gather_elements_along_row(logits, indices)
        labels = _gather_elements_along_row(labels, indices)

        return logits, labels