Open — SapereAudo opened this issue 1 year ago
class HardNegativeMining(tf.keras.layers.Layer):
    """Keeps, per row, the positive column(s) plus the highest-scoring negatives.

    Each row keeps ``min(num_hard_negatives + 1, num_columns)`` columns. These are
    chosen by ``top_k`` over ``logits + labels * MAX_FLOAT``, so entries with a
    positive label are boosted above the negatives. The remaining slots go to
    the negatives with the largest logits (the "hard" negatives).

    ``top_k`` is called with ``sorted=False``, so the kept columns come back in
    no particular order. Callers must use the returned ``labels`` to locate the
    positive(s) rather than assuming a fixed position.
    """

    def __init__(self, num_hard_negatives: int, **kwargs):
        """Creates the layer.

        Args:
            num_hard_negatives: Number of extra columns to keep per row beyond
                one slot (presumably intended for the single positive; confirm
                how multi-positive rows are expected to behave). Must be >= 0.
            **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.

        Raises:
            ValueError: If ``num_hard_negatives`` is negative.
        """
        # A negative count makes top_k's k <= 0. With k == 0 every column,
        # including the positive, is silently dropped; with k < 0 the call
        # fails later inside the graph. Reject it up front instead.
        if num_hard_negatives < 0:
            raise ValueError(
                f"num_hard_negatives must be non-negative, got {num_hard_negatives}."
            )
        super().__init__(**kwargs)
        self._num_hard_negatives = num_hard_negatives

    def call(self, logits: tf.Tensor, labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        """Selects the positive(s) and hardest negatives from each row.

        Args:
            logits: Candidate scores. Treated as 2-D: ``tf.shape(logits)[1]`` is
                used as the number of candidate columns, while ``top_k`` ranks
                along the last axis.
            labels: Assumed to match ``logits`` in shape and dtype (TODO
                confirm). Non-zero entries are boosted by ``labels * MAX_FLOAT``
                so they rank first, as long as a row has no more positives than
                kept slots.

        Returns:
            ``(logits, labels)``. Each is passed through
            ``_gather_elements_along_row`` with the selected column indices
            (presumably a per-row gather; verify against its definition).
        """
        num_sampled = tf.minimum(self._num_hard_negatives + 1, tf.shape(logits)[1])
        # Push labelled entries to the top so they are always among the k kept.
        _, indices = tf.nn.top_k(logits + labels * MAX_FLOAT, k=num_sampled, sorted=False)
        logits = _gather_elements_along_row(logits, indices)
        labels = _gather_elements_along_row(labels, indices)
        return logits, labels

    def get_config(self) -> dict:
        """Returns the layer config so the layer can be serialized and re-created.

        In tf.keras 2.x, a layer whose ``__init__`` takes extra arguments must
        override ``get_config``. Otherwise the base implementation raises
        ``NotImplementedError`` and models containing this layer cannot be saved.
        """
        config = super().get_config()
        config["num_hard_negatives"] = self._num_hard_negatives
        return config
class HardNegativeMining(tf.keras.layers.Layer): """Hard Negative"""