omoindrot / tensorflow-triplet-loss

Implementation of triplet loss in TensorFlow
https://omoindrot.github.io/triplet-loss
MIT License

Multi domain triplet loss #54

Closed: emizzz closed this issue 4 years ago

emizzz commented 4 years ago

Hi Omoindrot, great explanation and implementation of triplet loss.

I’m building a model that uses the online triplet loss, but in my case the network works on two domains: the drawings (sketches) domain and the images domain. The purpose of the network is to learn a good embedding for SBIR (sketch-based image retrieval), i.e. to close the gap between drawings and images (the two branches partially share their weights).
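At test time, SBIR with such an embedding comes down to nearest-neighbour search: embed the sketch with the drawing branch, embed the gallery with the image branch, and rank the images by distance to the sketch. The following is a minimal pure-Python sketch of that ranking step, assuming the embeddings are already computed; `euclidean`, `retrieve` and the 2-D vectors are made up for illustration.

```python
import math
from typing import List, Sequence


def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def retrieve(query: Sequence[float], gallery: List[Sequence[float]], k: int) -> List[int]:
    """Indices of the k gallery embeddings closest to the query, nearest first."""
    ranked = sorted(range(len(gallery)), key=lambda j: euclidean(query, gallery[j]))
    return ranked[:k]


# hypothetical 2-D embeddings: one sketch query (drawing branch)
# and three gallery images (image branch)
sketch = [0.9, 0.1]
images = [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]]
print(retrieve(sketch, images, k=2))  # [1, 2]
```

The training loss only has to make this ranking correct, which is why the anchors and the candidates come from different domains.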

I tried to adapt the loss function to the two domains as follows:

  1. I built a (Keras) generator that outputs two embedding vectors (drawings and icons) with the same class indices:

e.g.:

X_out = [ 
  [ draw_emb_class2, draw_emb_class1, draw_emb_class3, … ],
  [ icon_emb_class2, icon_emb_class1, icon_emb_class3, … ],
]
y_out = [ 
  [ 2, 1, 3, … ]
]
  2. Then I built a matrix of pairwise distances, with the drawing embeddings as anchors and the images as positives/negatives (not the best code...). The resulting distance matrix looks like this:
    # d_ is the distance
    pairwise_dist = [
    d_draw1-draw1, d_draw1-image2, d_draw1-image3, 
    d_draw2-image1, d_draw2-draw2, d_draw2-image3, 
    d_draw3-image1, d_draw3-image2, d_draw3-draw3, 
    ]

    Here is the code (changes are in the "multi domain version" sections):


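import tensorflow as tf


def batch_hard_triplet_loss_multi_domain(y_true, y_pred, margin=1, squared=False):
    # note: `squared` is not used; multidomain_pairwise_dist always returns Euclidean distances

    # y_true is expected with shape (batch_size, 1); squeeze it to (batch_size,)
    labels = tf.squeeze(y_true, axis=-1)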

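    # *************************multi domain version*****************************
    # assumes y_pred = [drawing_emb | image_emb] concatenated on the feature axis,
    # i.e. shape (batch_size, 2 * emb_dim)
    _emb_dim = y_pred.shape[1] // 2

    # each embedding = (batch_size, emb_dim)
    embedding_d = y_pred[:, :_emb_dim]   # drawings
    embedding_i = y_pred[:, _emb_dim:]   # images

    # rows = drawings (anchors), columns = images (positives / negatives)
    pairwise_dist_same = multidomain_pairwise_dist(embedding_d, embedding_d)
    pairwise_dist_diff = multidomain_pairwise_dist(embedding_d, embedding_i)

    # replace the diagonal d(draw_k, image_k) with d(draw_k, draw_k)
    pairwise_dist = tf.linalg.set_diag(
        pairwise_dist_diff,
        tf.linalg.diag_part(pairwise_dist_same),                                # all zeros
    )
    # **************************************************************************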
    # **************************************************************************

    # For each anchor, get the hardest positive
    # First, we need to get a mask for every valid positive (they should have same label)

    # *************************multi domain version*****************************
    mask_anchor_positive = _get_anchor_positive_triplet_mask_multi_domain(labels)
    # **************************************************************************

    mask_anchor_positive = tf.cast(mask_anchor_positive, tf.float32)

    # We put to 0 any element where (a, p) is not valid (valid if label(a) == label(p));
    # the diagonal entries were already set to 0 above
    anchor_positive_dist = tf.math.multiply(mask_anchor_positive, pairwise_dist)

    # shape (batch_size, 1)
    hardest_positive_dist = tf.math.reduce_max(anchor_positive_dist, axis=1, keepdims=True)
    tf.summary.scalar("hardest_positive_dist", tf.math.reduce_mean(hardest_positive_dist))

    # For each anchor, get the hardest negative
    # First, we need to get a mask for every valid negative (they should have different labels)
    mask_anchor_negative = _get_anchor_negative_triplet_mask(labels)  # helper from the single-domain code
    mask_anchor_negative = tf.cast(mask_anchor_negative, tf.float32)

    # We add the maximum value in each row to the invalid negatives (label(a) == label(n))
    max_anchor_negative_dist = tf.math.reduce_max(pairwise_dist, axis=1, keepdims=True)
    anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)

    # shape (batch_size, 1)
    hardest_negative_dist = tf.math.reduce_min(anchor_negative_dist, axis=1, keepdims=True)
    tf.summary.scalar("hardest_negative_dist", tf.math.reduce_mean(hardest_negative_dist))

    # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
    triplet_loss = tf.math.maximum(hardest_positive_dist - hardest_negative_dist + margin, 0.0)

    # Get final mean triplet loss
    triplet_loss = tf.math.reduce_mean(triplet_loss)

    return triplet_loss
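To check the masking rules on toy numbers without TensorFlow, the row-wise logic of the loss above can be mirrored in plain Python. This is a minimal sketch, assuming a precomputed distance matrix with drawings on the rows (anchors), images on the columns, and one label list shared by both domains. `batch_hard_cross_domain` and its `zero_diagonal` flag are hypothetical names; `zero_diagonal=True` reproduces the `tf.linalg.set_diag` step, `zero_diagonal=False` is a swapped-in variant that keeps d(drawing_k, image_k).

```python
from typing import List


def batch_hard_cross_domain(dist: List[List[float]], labels: List[int],
                            margin: float = 1.0, zero_diagonal: bool = True) -> float:
    """Batch-hard triplet loss on a cross-domain distance matrix.

    dist[a][j] is d(drawing_a, image_j); both domains share the label order.
    """
    n = len(labels)
    d = [row[:] for row in dist]          # copy so the caller's matrix is untouched
    if zero_diagonal:                     # mirrors tf.linalg.set_diag(..., zeros)
        for k in range(n):
            d[k][k] = 0.0
    losses = []
    for a in range(n):
        # hardest positive: max over same-label columns, other columns masked to 0
        hardest_pos = max(d[a][j] if labels[j] == labels[a] else 0.0 for j in range(n))
        # hardest negative: same-label columns are pushed up by the row maximum,
        # then the smallest remaining distance is taken
        row_max = max(d[a])
        hardest_neg = min(d[a][j] + (row_max if labels[j] == labels[a] else 0.0)
                          for j in range(n))
        losses.append(max(hardest_pos - hardest_neg + margin, 0.0))
    return sum(losses) / n


labels = [2, 1, 3]                        # one drawing/image pair per class
dist = [[1.5, 1.0, 3.0],
        [2.0, 0.4, 2.5],
        [2.2, 1.8, 0.3]]
print(batch_hard_cross_domain(dist, labels))                       # 0.0
print(batch_hard_cross_domain(dist, labels, zero_diagonal=False))  # 0.5
```

With one pair per class, as in the `y_out` example, the only same-label entry in each row is the diagonal. Zeroing it makes every hardest positive 0, so under these rules the loss only pushes negatives away and never pulls a matched drawing and image together, which is why the first call returns 0.0 even though drawing 0 lies closer to image 1 than to its own image.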

And these are the helper functions it calls:

def _get_anchor_positive_triplet_mask_multi_domain(labels):

    # *************************multi domain version*****************************
    mask = tf.math.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
    # **************************************************************************

    return mask

def multidomain_pairwise_dist(A, B):
    """
    Computes pairwise Euclidean distances between each row of A and each row of B.
    Args:
      A,    [m,d] matrix
      B,    [n,d] matrix
    Returns:
      D,    [m,n] matrix of pairwise distances
    """
    # squared norms of each row in A and B
    na = tf.reduce_sum(tf.square(A), 1)
    nb = tf.reduce_sum(tf.square(B), 1)

    # na as a column vector and nb as a row vector, so they broadcast to [m,n]
    na = tf.reshape(na, [-1, 1])
    nb = tf.reshape(nb, [1, -1])

    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, clamped at a small epsilon so that
    # tf.sqrt never sees exactly 0 (its gradient is infinite there and yields NaNs)
    D = tf.sqrt(tf.maximum(na - 2 * tf.matmul(A, B, False, True) + nb, 1e-16))
    return D
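The distance helper relies on the expansion ||a - b||^2 = ||a||^2 - 2 a·b + ||b||^2, and the first argument indexes the rows of the result, so it is the first argument that supplies the anchors in the row-wise batch-hard loss. A minimal pure-Python sketch of the same computation, checked against `math.dist`; `pairwise_dist_expanded` and its `eps` parameter are hypothetical names, and `eps` plays the role of the clamp value (in TensorFlow a positive clamp also keeps the `tf.sqrt` gradient finite at zero distance):

```python
import math
from typing import List


def pairwise_dist_expanded(A: List[List[float]], B: List[List[float]],
                           eps: float = 0.0) -> List[List[float]]:
    """D[i][j] = ||A[i] - B[j]|| computed via ||a||^2 - 2 a.b + ||b||^2."""
    na = [sum(x * x for x in a) for a in A]   # squared norms of rows of A (one per row of D)
    nb = [sum(x * x for x in b) for b in B]   # squared norms of rows of B (one per column of D)
    D = []
    for i, a in enumerate(A):
        row = []
        for j, b in enumerate(B):
            dot = sum(x * y for x, y in zip(a, b))
            # clamp: rounding can make the expansion slightly negative
            row.append(math.sqrt(max(na[i] - 2 * dot + nb[j], eps)))
        D.append(row)
    return D


A = [[0.0, 0.0], [1.0, 1.0]]   # e.g. drawing embeddings (rows / anchors)
B = [[3.0, 4.0], [1.0, 1.0]]   # e.g. image embeddings (columns)
D = pairwise_dist_expanded(A, B)
print(D[0][0], D[1][1])  # 5.0 0.0
```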

Now, do you think this approach makes sense, and could it work with the online triplet loss?

Thank you.