tensorflow / similarity


Lifted Structured loss #102

ebursztein opened 3 years ago

ebursztein commented 3 years ago

Implement the Lifted Structured Loss - https://arxiv.org/abs/1511.06452 https://www.tensorflow.org/addons/api_docs/python/tfa/losses/LiftedStructLoss

abhisharsinha commented 2 years ago

Hi @owenvallis, if no one has already started working on it, can I attempt this?

owenvallis commented 2 years ago

Hi @abhisharsinha, that would be great!

I started an implementation based on the one in tf addons:

https://github.com/tensorflow/similarity/blob/new_losses/tensorflow_similarity/losses/lifted_struct_loss.py

but I haven't checked it for correctness.

Lorenzobattistela commented 1 year ago

Hi @owenvallis, I wonder if this issue is still open. I'm looking forward to contributing to it, and I'm doing some research on the paper and related material. I see that an implementation has been started; is the remaining work just to revise it for correctness? Thanks in advance.

owenvallis commented 1 year ago

Hi @Lorenzobattistela, thanks for offering to pick this up! This is still currently open. The "initial" version I worked on is still in this file, but I never verified the implementation. One note is that I flipped the loss to use distance rather than similarity.

It would be great if you'd like to pick this up and add it to the losses. Feel free to make a PR to the development branch and assign me as the reviewer.

Lorenzobattistela commented 1 year ago

I'm studying it (still a beginner here), but working on it!

Lorenzobattistela commented 1 year ago

Hi @owenvallis, I wanted to report some progress and ask some questions. Please correct me if I'm wrong in any assumption. I've just reviewed the code against the paper. The loss over each batch is defined as:

$$J = \frac{1}{2|P|} \sum_{(i,j) \in P} \max\big(0, \tilde{J}_{i,j}\big)^2$$

where $\tilde{J}_{i,j}$ is defined as:

$$\tilde{J}_{i,j} = \log\Bigg(\sum_{(i,k) \in N} \exp\{\alpha - D_{i,k}\} + \sum_{(j,l) \in N} \exp\{\alpha - D_{j,l}\}\Bigg) + D_{i,j}$$

$P$ is the set of positive pairs (same class) and $N$ is the set of negative pairs (different classes). $\alpha$ is the margin, and $D_{i,j}$ in the second equation is the pairwise distance between embeddings $i$ and $j$.

So, in this snippet we calculate $D_{i,j}$ and the value inside the exponential, which is the difference between $\alpha$ and these distances:

pairwise_distances = distance(embeddings)  # D_{i,j}: pairwise distance matrix over the batch
diff = margin - pairwise_distances         # alpha - D_{i,j}, the exponent in J~_{i,j}
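
For reference, here is a minimal sketch of what distance(embeddings) might compute if we assume plain Euclidean distance (the repo's actual Distance classes may differ):

import tensorflow as tf

def pairwise_euclidean(embeddings):
    # ||a - b||^2 = ||a||^2 - 2<a, b> + ||b||^2, computed for all pairs at once.
    sq_norms = tf.math.reduce_sum(tf.math.square(embeddings), axis=1, keepdims=True)
    sq_dists = sq_norms - 2.0 * tf.linalg.matmul(embeddings, embeddings, transpose_b=True) + tf.transpose(sq_norms)
    # Clamp tiny negatives caused by floating-point error before the sqrt.
    return tf.math.sqrt(tf.math.maximum(sq_dists, 0.0))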

Then we use masks to identify pairs with the same or different labels, and tiles are used to perform the computation between each pair of embeddings in the batch (sketched below).
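
Roughly, following the tf-addons pattern (the variable names here are my assumptions, with labels of shape [batch_size, 1]):

adjacency = tf.math.equal(labels, tf.transpose(labels))         # True where labels match
neg_mask = tf.cast(tf.math.logical_not(adjacency), tf.float32)  # 1.0 for negative pairs
batch_size = tf.shape(labels)[0]
# Tile so the row block for each anchor can be combined with every pair (i, j).
diff_tiled = tf.tile(diff, [batch_size, 1])
neg_mask_tiled = tf.tile(neg_mask, [batch_size, 1])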

Now we calculate the exponential expression and reshape it; subtracting max_elements_vect first is the usual log-sum-exp trick for numerical stability:

loss_exp_left = tf.math.multiply(
        tf.math.exp(diff_tiled - max_elements_vect),  # exp(alpha - D - max), stabilized
        neg_mask_tiled,                               # zero out non-negative pairs
    )

loss_exp_left = tf.reshape(
        tf.math.reduce_sum(loss_exp_left, 1, keepdims=True),  # sum over the negatives
        [batch_size, batch_size],
    )

We sum both exponentials, then add max_elements and pairwise_distances (which I assume corresponds to the $D_{i,j}$ term in the equation) to the log of our previous calculations.
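
A minimal sketch of that step, assuming loss_exp_right is the mirrored sum over the second set of negatives (the draft may instead use tf.transpose(loss_exp_left)):

# Undo the stability shift and recover log(sum of exponentials), then add the
# positive-pair distance to obtain J~_{i,j} for every pair in the batch.
loss_mat = max_elements + tf.math.log(loss_exp_left + loss_exp_right)
loss_mat = loss_mat + pairwise_distances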

I'm a little lost with the following snippet:

# *0.5 for upper triangular, and another *0.5 for 1/2 factor for loss^2.
    num_positives = tf.math.reduce_sum(positive_mask) / 2.0

    lifted_loss = tf.math.truediv(
        0.25
        * tf.math.reduce_sum(
            tf.math.square(
                tf.math.maximum(tf.math.multiply(loss_mat, positive_mask), 0.0)
            )
        ),
        num_positives,
    )

I understand squaring the maximum against 0.0 and the multiplication by the positive mask, but I'm not really sure why we multiply by 0.25 (1/2^2? but why loss^2?).
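
Working through the constants (my tentative reading, so please correct me): positive_mask is symmetric, so the reduce_sum counts each unordered positive pair twice, and num_positives equals $|P|$. That would give

$$0.25 \cdot \frac{2 \sum_{(i,j) \in P} \max\big(0, \tilde{J}_{i,j}\big)^2}{|P|} = \frac{1}{2|P|} \sum_{(i,j) \in P} \max\big(0, \tilde{J}_{i,j}\big)^2$$

i.e. one factor of 0.5 undoes the double count from the symmetric matrix and the other is the 1/2 in the paper's $\frac{1}{2|P|}$ term.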

I started building some tests for it today. Another question: you mentioned you used distance instead of similarity. Should we change this to use similarity, then?

Thanks in advance

owenvallis commented 1 year ago

Hi @Lorenzobattistela, I had another look at the paper and I think we can greatly simplify the implementation I shared. Here are a few thoughts:

Now the only issue here is that we will compute the loss for each positive pair twice per batch, but I think that is taken care of by the 1 / (2|P|) term.

One other tricky piece is how to vectorize over the two sets of negatives per example in the batch. It would be cool if we could reorder the pairwise distance rows and negative masks based on the column index of the positive pairs (I think we can use tf.gather here) and then concat the two pairwise distance matrices along axis=1 (same for the negative masks). Then you could just pass that to the logsumexp function (see the sketch below).
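
A minimal sketch of that idea, assuming pos_idx[i] holds the column index of example i's positive partner and pos_pair_dists holds each pair's $D_{i,j}$ (all names here are my assumptions, not a final API):

# Row i of the gathered tensors is the distance row / negative mask of i's positive partner.
gathered_dists = tf.gather(pairwise_distances, pos_idx, axis=0)
gathered_neg_mask = tf.gather(neg_mask, pos_idx, axis=0)

# Concatenate i's negatives with its partner's negatives along axis=1, then take
# one logsumexp over (margin - D), masking out non-negative pairs with a large
# negative fill so they contribute ~0 to the sum.
all_dists = tf.concat([pairwise_distances, gathered_dists], axis=1)
all_neg_mask = tf.concat([neg_mask, gathered_neg_mask], axis=1)
masked = tf.where(all_neg_mask > 0.0, margin - all_dists, tf.fill(tf.shape(all_dists), -1e9))
j_tilde = tf.math.reduce_logsumexp(masked, axis=1) + pos_pair_dists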

Lorenzobattistela commented 1 year ago

Finally made some progress on it. I'll just make some final adjustments and open the PR.

Lorenzobattistela commented 1 year ago

This issue may be closed as well