omoindrot / tensorflow-triplet-loss

Implementation of triplet loss in TensorFlow
https://omoindrot.github.io/triplet-loss
MIT License

Classification using triplet loss embeddings #5

Open xiaahui opened 6 years ago

xiaahui commented 6 years ago

Thank you for your tutorial and implementation of triplet loss. I have one question about how to use triplet loss for classification. If we need to do classification, could we use the embeddings as features and train a model like an SVM or GBDT on top of them? This way the triplet-loss network works like a feature extractor. Do you have any other suggestions for classification? Thanks.

omoindrot commented 6 years ago

As you say, triplet loss is one technique to learn a good embedding. The whole field is called representation learning, and the goal is to learn a good representation (the embedding) of an input (for instance, an image).

If you have already trained a good model with triplet loss, you can definitely use it as a feature extractor and then feed the features to a simpler model.

Usually triplet loss is used because it makes the embedding space discriminative: two embeddings of the same person should always be closer to each other than two embeddings of distinct people.
Using that knowledge, you can do face recognition for instance by comparing a new image of a face to all the training faces you have, and give it the label of the closest embedding.

To put it another way, you perform a sort of k-NN with k=1 on the embeddings to find the label of a new face.
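That 1-NN lookup can be sketched in a few lines of numpy (all array names here are hypothetical; in practice the embeddings would come from the trained network):

```python
import numpy as np

def nearest_neighbor_label(query_embedding, train_embeddings, train_labels):
    # Squared Euclidean distance from the query to every training embedding.
    distances = np.sum((train_embeddings - query_embedding) ** 2, axis=1)
    # Label of the closest training embedding (k-NN with k=1).
    return train_labels[np.argmin(distances)]

# Toy usage: two "identities" in a hypothetical 2-D embedding space.
train_embeddings = np.array([[0.0, 0.0], [1.0, 1.0]])
train_labels = np.array([0, 1])
label = nearest_neighbor_label(np.array([0.1, 0.2]), train_embeddings, train_labels)  # -> 0
```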

XZNWU commented 6 years ago

How do you use triplet loss for classification? Do you have any code to share?

HuaZheLei commented 6 years ago

You can just run a kNN algorithm here. For instance, run a 3-NN search in the embedding space, collect the labels of the 3 nearest neighbors, and take the majority vote.
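A minimal sketch of that majority-vote kNN in numpy (the arrays below are hypothetical toy data):

```python
import numpy as np
from collections import Counter

def knn_label(query, train_embeddings, train_labels, k=3):
    # Distance from the query embedding to every training embedding.
    distances = np.linalg.norm(train_embeddings - query, axis=1)
    nearest = np.argsort(distances)[:k]        # indices of the k closest embeddings
    votes = Counter(train_labels[nearest].tolist())
    return votes.most_common(1)[0][0]          # majority label among the k neighbors

# Toy usage: two clusters in a hypothetical 2-D embedding space.
train_embeddings = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [1.0, 1.0], [0.9, 1.1]])
train_labels = np.array([0, 0, 0, 1, 1])
```

scikit-learn's `KNeighborsClassifier` implements the same idea with less code, if you don't mind the dependency.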

bzhong2 commented 6 years ago

Since estimator.predict can only return the embeddings but not the labels, I have to use the following code to get the labels of the training data for the 3-NN search. However, I got the following error: GraphDef cannot be larger than 2GB. Is there a better way to do classification?

dataset = cifar_dataset.train()
dataset = dataset.map(lambda img, lab: lab)    # keep only the labels
dataset = dataset.batch(train_length)          # one batch holding the entire training set
labels_tensor = dataset.make_one_shot_iterator().get_next()
train_labels = sess.run(labels_tensor)
omoindrot commented 6 years ago

One potential issue is that you create a batch on the full CIFAR dataset, which might be too big.

Maybe try smaller batches and concatenate the resulting labels:

import numpy as np
import tensorflow as tf

dataset = cifar_dataset.train()
dataset = dataset.map(lambda img, lab: lab)  # keep only the labels
dataset = dataset.batch(128)                 # small batches instead of one giant batch

labels_tensor = dataset.make_one_shot_iterator().get_next()

train_labels = []
try:
    # Drain the dataset batch by batch until it is exhausted.
    while True:
        train_labels.append(sess.run(labels_tensor))
except tf.errors.OutOfRangeError:
    pass

train_labels = np.concatenate(train_labels)
HuaZheLei commented 6 years ago

@omoindrot We can learn embeddings through triplet loss, but we cannot get the labels directly. Therefore, I wonder if you could add some FC layers with a softmax layer on top of the triplet-loss embedding, so that the softmax layer learns a classification boundary and outputs scores in a single forward pass.

omoindrot commented 6 years ago

You could for instance train a good embedding with triplet loss, then freeze this network and add one or more fully connected layers with a final softmax.

image --> frozen network --> embedding --> FC --> FC --> softmax

However, this only works when you know the number of classes in advance.
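Assuming a known number of classes, the softmax head on top of frozen embeddings can be sketched in plain numpy as a single dense layer trained with cross-entropy (everything here is hypothetical toy data; in a real project you would add the layers in your training framework instead):

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def train_softmax_head(embeddings, labels, num_classes, lr=0.5, steps=200):
    # One fully connected layer + softmax, trained on frozen embeddings.
    n, d = embeddings.shape
    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[labels]
    for _ in range(steps):
        probs = softmax(embeddings @ W + b)
        grad = probs - onehot              # gradient of cross-entropy wrt logits
        W -= lr * embeddings.T @ grad / n
        b -= lr * grad.mean(axis=0)
    return W, b

# Toy usage: 4 frozen 2-D embeddings, 2 classes.
emb = np.array([[-1.0, -1.0], [1.0, 1.0], [-1.0, -0.5], [1.0, 0.5]])
labels = np.array([0, 1, 0, 1])
W, b = train_softmax_head(emb, labels, num_classes=2)
preds = np.argmax(emb @ W + b, axis=1)
```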

Triplet loss is often used in projects where you don't know the number of classes, such as in face recognition.
With a triplet loss trained embedding, you can easily check if two faces are close together or not, and have a threshold to indicate whether they belong to the same person or not.
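In code, that verification check is just a distance comparison (a sketch; the 0.5 threshold here is a placeholder, and choosing it is discussed further down this thread):

```python
import numpy as np

def same_person(emb_a, emb_b, threshold=0.5):
    # Placeholder threshold: tune it on validation pairs for a real system.
    dist = np.sum((emb_a - emb_b) ** 2)  # squared Euclidean distance
    return dist < threshold
```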

jbieliauskas commented 6 years ago

What threshold should you use to compare two embeddings? Is it the margin hyperparameter you used in training?

I posted this question on stats.stackexchange.com and datascience.stackexchange.com

omoindrot commented 6 years ago

@justasbieliauskas: yes, the margin parameter can be used as a threshold to classify two images as belonging to the same class or not.

bartoszpawlowicz commented 6 years ago

@omoindrot @justasbieliauskas I do not agree that you can treat the margin as the threshold, because the margin defines the target difference between positive- and negative-pair distances (Dap - Dan), not the distance itself. Experimenting with a few different network architectures, I found that the optimal threshold differs between architectures. It can vary especially if you normalize the output embedding vector and/or use ReLU at the last (embedding) layer (which you should not do).
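One common way to handle this in practice is to sweep candidate thresholds over a validation set of same/different pairs and keep the one with the best accuracy. A minimal sketch (all names and data are hypothetical):

```python
import numpy as np

def best_threshold(distances, same_labels, candidates):
    # distances:   embedding distances for validation pairs
    # same_labels: True when a pair belongs to the same identity
    # candidates:  thresholds to try; returns the most accurate one
    accuracies = [np.mean((distances < t) == same_labels) for t in candidates]
    return candidates[int(np.argmax(accuracies))]

# Toy usage: two genuine pairs (small distances), two impostor pairs (large).
dists = np.array([0.1, 0.2, 0.9, 1.1])
same = np.array([True, True, False, False])
t = best_threshold(dists, same, np.linspace(0.0, 1.5, 16))
```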

omoindrot commented 6 years ago

Good point @bartoszpawlowicz