samehraban opened this issue 5 years ago
I was facing the same issue with hard batching. I lowered the learning rate and made the change to the hard-batch function suggested in https://github.com/omoindrot/tensorflow-triplet-loss/issues/18#issuecomment-455031523, and it started working. During training I saw the loss get stuck at the margin for a while and then start going down again.
I have a similar problem. I am using the batch-all strategy and my fraction of positive triplets is always 1.0. The loss oscillates around the margin, but the fraction of positive triplets does not decrease. I've tried lowering the learning rate, but nothing happens. I've also checked that the embeddings do not collapse to a point (rough sketch of that check below).
Edit: I solved the problem. I was passing tensors with the wrong shape to batch_all_triplet_loss().
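For anyone wondering what I mean by checking for collapse, here is a rough sketch. The helper name is mine, not from the repo; it just computes the mean pairwise Euclidean distance over a batch of embeddings:

```python
import tensorflow as tf

def mean_pairwise_distance(embeddings):
    """If this value is close to 0, the embeddings have collapsed to a single point."""
    # Squared norms of each embedding, shape (batch_size, 1)
    sq_norms = tf.reduce_sum(tf.square(embeddings), axis=1, keepdims=True)
    # Pairwise squared Euclidean distances: ||a||^2 - 2*a.b + ||b||^2
    sq_dists = sq_norms - 2.0 * tf.matmul(embeddings, embeddings, transpose_b=True) + tf.transpose(sq_norms)
    sq_dists = tf.maximum(sq_dists, 0.0)  # clamp tiny negatives from floating point error
    return tf.reduce_mean(tf.sqrt(sq_dists))
```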
@laiadc I am struggling with the same issue. How did you solve it? I don't think anything is wrong with the tensors I pass in. I am applying the code to my own dataset.
@laiadc: Thanks for the hint about the tensors' shapes. That was the issue for me as well.
@kurehawaru: Here is what helped me to solve the issue.
First I inserted these three lines at the very beginning of the `batch_all_triplet_loss` function:

```python
print()
print('labels.shape: {}'.format(labels.shape))
print('embeddings.shape: {}'.format(embeddings.shape))
```
The output I got was the following:

```
labels.shape: (None, 1)
embeddings.shape: (None, 1623)
```

(`None` is simply to be interpreted as a placeholder for whatever your `batch_size` is.)
Since the `batch_all_triplet_loss` function needs the labels to be of shape `(None,)`, I needed to add a `tf.squeeze(labels)`. It can be added either directly at the beginning of the `batch_all_triplet_loss` function or at the point where it gets called (like I did). For instance:

```python
def batch_all_triplet_loss_caller(y_true, y_pred, margin=0.5):
    loss, error = batch_all_triplet_loss(tf.squeeze(y_true), y_pred, margin)
    return loss, error
```
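As a quick sanity check, something like the following should run without shape complaints once the squeeze is in place. This is only a sketch written TF 1.x-style (which I believe the repo uses); the import path is just how the repo lays things out, and the 32/1623 are placeholder shapes mirroring the printout above:

```python
import tensorflow as tf
from model.triplet_loss import batch_all_triplet_loss  # import path assumed from the repo layout

# Dummy batch with the shapes printed above: labels (32, 1), embeddings (32, 1623).
labels = tf.random_uniform((32, 1), maxval=10, dtype=tf.int64)
embeddings = tf.random_normal((32, 1623))

# tf.squeeze turns the (32, 1) labels into the (32,) shape the function expects.
loss, fraction = batch_all_triplet_loss(tf.squeeze(labels), embeddings, 0.5)
with tf.Session() as sess:
    print(sess.run([loss, fraction]))
```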
Hope that helps.
I'm trying to use triplet loss to train a ResNet on arbitrary datasets, but I can't get it to work. When I train on MNIST using your repo, I see fraction_positive decreasing over time, but for the ResNet it goes up instead. The batch-all and batch-hard losses start way above the margin and decrease over time to the margin, which means the mean distance is about zero.
At first I thought my data was simply not sufficient for triplet loss, but since I see the very same pattern on MNIST, I think there might be a problem somewhere.
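This pattern is exactly what a collapsed embedding produces: if all outputs sit at a single point, every distance is zero, so each triplet contributes max(d(a, p) - d(a, n) + margin, 0) = margin, the batch-all loss sits at the margin, and the fraction of positive triplets stays at 1.0. A minimal way to see this (a sketch only, TF 1.x-style; the import path is assumed from the repo layout and the margin/shapes are arbitrary):

```python
import tensorflow as tf
from model.triplet_loss import batch_all_triplet_loss  # import path assumed from the repo layout

# 16 samples from two classes whose embeddings have all collapsed to the same point.
labels = tf.constant([0, 1] * 8, dtype=tf.int64)
embeddings = tf.ones((16, 64), dtype=tf.float32)

loss, fraction_positive = batch_all_triplet_loss(labels, embeddings, margin=0.5)
with tf.Session() as sess:
    print(sess.run([loss, fraction_positive]))  # about [0.5, 1.0]: loss pinned at the margin
```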