Hi Omoindrot,
Great explanation and implementation of triplet loss.
I’m building a model that uses the online triplet loss, but in my case the network works on two domains: the drawing domain and the image domain. The goal of this network is to learn a good embedding for SBIR (sketch-based image retrieval), that is, to bridge the gap between drawings and images (the two branches partially share their weights).
I tried to adapt the function to the two domains as follows.
I built a (Keras) generator that outputs two embedding vectors (drawings and icons) with the same class indices.
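The generator code isn't shown. Judging from how the loss below slices `y_pred`, it presumably concatenates the drawing and image embeddings along the feature axis, so row k holds both embeddings for class `labels[k]`. A minimal NumPy sketch of that assumed layout (all names and sizes here are hypothetical):

```python
import numpy as np

# Hypothetical batch of 3 samples with 4-dimensional embeddings per domain
rng = np.random.default_rng(0)
labels = np.array([0, 1, 2])          # one class index per row, shared by both domains
draw_emb = rng.normal(size=(3, 4))    # row k: drawing of class labels[k]
image_emb = rng.normal(size=(3, 4))   # row k: image (icon) of class labels[k]

# Assumed generator output: both embeddings concatenated along the feature axis
y_pred = np.concatenate([draw_emb, image_emb], axis=1)  # shape (3, 8)
y_true = labels.reshape(-1, 1)                          # Keras-style (batch_size, 1) labels

# Splitting the second axis in half recovers each domain
emb_dim = y_pred.shape[1] // 2
embedding_d = y_pred[:, :emb_dim]
embedding_i = y_pred[:, emb_dim:]
print(y_pred.shape, embedding_d.shape, embedding_i.shape)
```

Under this layout, `y_pred.shape[1] // 2` is half the feature dimension (the per-domain embedding size), not the batch size.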
Then I built a matrix of pairwise distances, with the drawing embeddings as anchors and the images as positives/negatives (not the best code...). The resulting distance matrix looks like this:
```
# d_x-y is the distance between x and y
pairwise_dist = [
    d_draw1-draw1,  d_draw1-image2, d_draw1-image3,
    d_draw2-image1, d_draw2-draw2,  d_draw2-image3,
    d_draw3-image1, d_draw3-image2, d_draw3-draw3,
]
```
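A small NumPy sketch of how such a matrix can be built, assuming Euclidean distances and made-up 2-D embeddings where row k of each array belongs to class k: rows are drawing anchors, columns are image candidates, and the diagonal d(draw_k, image_k) is replaced by d(draw_k, draw_k), which is 0.

```python
import numpy as np

# Hypothetical 2-D embeddings; row k of each array belongs to class k
draw_emb = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
image_emb = np.array([[0.0, 1.0], [1.0, 1.0], [3.0, 2.0]])

# dist[k, j] = ||draw_k - image_j||: rows are drawing anchors, columns are images
dist = np.linalg.norm(draw_emb[:, None, :] - image_emb[None, :, :], axis=-1)

# Replace the diagonal d(draw_k, image_k) with d(draw_k, draw_k) = 0
np.fill_diagonal(dist, 0.0)
print(np.round(dist, 3))
```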
Here is the code (changes are in the "multi domain version" sections):
```python
import tensorflow as tf


def batch_hard_triplet_loss_multi_domain(y_true, y_pred, margin=1, squared=False):
    # Keras usually passes y_true as (batch_size, 1); squeeze it to (batch_size,)
    labels = tf.squeeze(y_true, axis=-1)

    # *************************multi domain version*****************************
    # y_pred = (batch_size, 2 * emb_dim): drawing embedding followed by image embedding,
    # so the feature axis (not the batch axis) is split in half
    _emb_dim = y_pred.shape[1] // 2
    embedding_d = y_pred[:, :_emb_dim]  # drawings: anchors (rows)
    embedding_i = y_pred[:, _emb_dim:]  # images: positives/negatives (columns)

    pairwise_dist_same = multidomain_pairwise_dist(embedding_d, embedding_d)
    pairwise_dist_diff = multidomain_pairwise_dist(embedding_d, embedding_i)
    # Replace the diagonal d(draw_k, image_k) with d(draw_k, draw_k)
    pairwise_dist = tf.linalg.set_diag(
        pairwise_dist_diff,
        tf.linalg.diag_part(pairwise_dist_same),  # all zeros
    )
    # **************************************************************************

    # For each anchor, get the hardest positive
    # First, we need a mask for every valid positive (same label)
    # *************************multi domain version*****************************
    mask_anchor_positive = _get_anchor_positive_triplet_mask_multi_domain(labels)
    # **************************************************************************
    mask_anchor_positive = tf.cast(mask_anchor_positive, tf.float32)

    # Set to 0 every element where (a, p) is not valid (valid if label(a) == label(p))
    anchor_positive_dist = tf.math.multiply(mask_anchor_positive, pairwise_dist)

    # shape (batch_size, 1)
    hardest_positive_dist = tf.math.reduce_max(anchor_positive_dist, axis=1, keepdims=True)
    tf.summary.scalar("hardest_positive_dist", tf.math.reduce_mean(hardest_positive_dist))

    # For each anchor, get the hardest negative
    # First, we need a mask for every valid negative (different labels);
    # _get_anchor_negative_triplet_mask is the unchanged single-domain helper
    mask_anchor_negative = _get_anchor_negative_triplet_mask(labels)
    mask_anchor_negative = tf.cast(mask_anchor_negative, tf.float32)

    # Add the maximum value of each row to the invalid negatives (label(a) == label(n))
    max_anchor_negative_dist = tf.math.reduce_max(pairwise_dist, axis=1, keepdims=True)
    anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)

    # shape (batch_size, 1)
    hardest_negative_dist = tf.math.reduce_min(anchor_negative_dist, axis=1, keepdims=True)
    tf.summary.scalar("hardest_negative_dist", tf.math.reduce_mean(hardest_negative_dist))

    # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
    triplet_loss = tf.math.maximum(hardest_positive_dist - hardest_negative_dist + margin, 0.0)

    # Get final mean triplet loss
    triplet_loss = tf.math.reduce_mean(triplet_loss)
    return triplet_loss
```
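To check the batch-hard selection by hand, here is a NumPy mirror of the positive/negative mining steps, applied to a precomputed anchor-by-candidate distance matrix. The helper name `batch_hard_from_dist` and the 4-sample matrix are hypothetical, and the diagonal is assumed to be already set to 0 as in the multi-domain version.

```python
import numpy as np


def batch_hard_from_dist(dist, labels, margin=1.0):
    """Batch-hard triplet loss from a precomputed [B, B] anchor-candidate distance matrix."""
    same = labels[None, :] == labels[:, None]  # positive mask, diagonal kept
    # Hardest positive: largest distance among same-label candidates
    hardest_pos = np.where(same, dist, 0.0).max(axis=1, keepdims=True)
    # Hardest negative: push same-label entries up by the row max, then take the minimum
    row_max = dist.max(axis=1, keepdims=True)
    hardest_neg = (dist + row_max * same).min(axis=1, keepdims=True)
    loss = np.maximum(hardest_pos - hardest_neg + margin, 0.0).mean()
    return loss, hardest_pos, hardest_neg


# Hypothetical distances for labels [0, 0, 1, 1], diagonal already set to 0
labels = np.array([0, 0, 1, 1])
dist = np.array([
    [0.0, 0.5, 2.0, 1.2],
    [0.8, 0.0, 1.5, 0.3],
    [2.0, 1.0, 0.0, 0.4],
    [1.1, 0.9, 0.6, 0.0],
])
loss, hp, hn = batch_hard_from_dist(dist, labels)
print(hp.ravel(), hn.ravel(), loss)
```

Because the diagonal holds zeros, it can only be the hardest positive when an anchor has no other same-label candidate in the batch.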
And these are the helper functions it calls:
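```python
def _get_anchor_positive_triplet_mask_multi_domain(labels):
    # *************************multi domain version*****************************
    # mask[a, p] is True iff label(a) == label(p). Unlike the single-domain mask,
    # pairs with a == p are not excluded, since anchors (drawings) and
    # positives (images) come from different domains.
    mask = tf.math.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
    # **************************************************************************
    return mask
```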
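To make the difference with the single-domain mask concrete, a small NumPy sketch with hypothetical labels. The single-domain helper in the original implementation additionally requires a != p (as far as I can tell from that code), while the multi-domain version keeps the diagonal.

```python
import numpy as np

labels = np.array([0, 0, 1])

# Multi-domain mask: every pair with the same label, diagonal included
mask_multi = labels[None, :] == labels[:, None]

# Single-domain style mask: same label AND a != p (diagonal removed)
mask_single = mask_multi & ~np.eye(len(labels), dtype=bool)

print(mask_multi.astype(int))
print(mask_single.astype(int))
```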
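```python
def multidomain_pairwise_dist(A, B):
    """
    Computes pairwise Euclidean distances between each element of A and each element of B.
    Args:
        A: [m, d] matrix
        B: [n, d] matrix
    Returns:
        D: [m, n] matrix of pairwise distances
    """
    # Squared norms of each row in A and B
    na = tf.reduce_sum(tf.square(A), 1)
    nb = tf.reduce_sum(tf.square(B), 1)

    # na as a column vector and nb as a row vector, so they broadcast to [m, n]
    na = tf.reshape(na, [-1, 1])
    nb = tf.reshape(nb, [1, -1])

    # Squared distances ||a||^2 - 2 a.b + ||b||^2, clipped at 0 against rounding errors
    D = tf.maximum(na - 2 * tf.matmul(A, B, transpose_b=True) + nb, 0.0)

    # The gradient of sqrt is infinite at 0 (it yields NaN through the masked entries),
    # so add a tiny epsilon to the zero entries before the sqrt and reset them afterwards
    zero_mask = tf.cast(tf.equal(D, 0.0), tf.float32)
    D = tf.sqrt(D + zero_mask * 1e-16)
    D = D * (1.0 - zero_mask)
    return D
```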
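The expansion used above, ||a - b||² = ||a||² - 2 a·b + ||b||², can be sanity-checked against a brute-force difference. A NumPy sketch with random, hypothetical embeddings (`pairwise_dist_np` is an illustrative name, not part of the code above):

```python
import numpy as np


def pairwise_dist_np(A, B):
    """NumPy version of the expansion trick: [m, d] x [n, d] -> [m, n]."""
    na = np.sum(A ** 2, axis=1).reshape(-1, 1)   # column vector [m, 1]
    nb = np.sum(B ** 2, axis=1).reshape(1, -1)   # row vector [1, n]
    sq = np.maximum(na - 2 * A @ B.T + nb, 0.0)  # clip tiny negative rounding errors
    return np.sqrt(sq)


rng = np.random.default_rng(42)
A = rng.normal(size=(5, 8))  # e.g. drawing embeddings
B = rng.normal(size=(7, 8))  # e.g. image embeddings
D = pairwise_dist_np(A, B)

# Brute force: explicit difference for every (i, j) pair
brute = np.linalg.norm(A[:, None, :] - B[None, :, :], axis=-1)
print(D.shape, np.allclose(D, brute))
```

The expansion avoids materializing the [m, n, d] difference tensor, which is why it scales better than the brute-force check for large batches.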
Now, do you think this approach makes sense and could work with the online triplet loss?
Thank you.