omoindrot / tensorflow-triplet-loss

Implementation of triplet loss in TensorFlow
https://omoindrot.github.io/triplet-loss
MIT License

Implementation of metrics to monitor training process in tf.keras environment #58

Closed paweller closed 3 years ago

paweller commented 3 years ago

Hello fellow developers,

First of all, a big thanks to you, @omoindrot, for your educational article and the accompanying repository. I am currently using the given triplet loss implementation in a TensorFlow-Keras environment, basically following @ma1112's solution and the GitHub repository he presented in #18.

I know that the triplet loss implementation presented here was not written natively for a TensorFlow-Keras environment. Nevertheless, I feel it is one of the best custom implementations out there, which is why I would really like to stick with it. Having said that, I am currently facing an issue. To monitor the training process, I would like to use metrics like precision, recall, or simply what I will call a triplet error rate (number_of(d(a,p) - d(a,n) > 0) / number_total_triplets).

From what I have tried so far, the built-in tf.keras.metrics (e.g. AUC, Precision or Recall) do not work when they are simply specified in the tf.keras.Model.compile function. I always get the error I described in this stackoverflow question.

Now, my question is: how could I, for example, implement a triplet error rate function in the TensorFlow-Keras environment while keeping omoindrot's triplet loss? Here is some dummy code showing what I am trying to achieve:

def triplet_error_rate(y_true, y_pred, margin=0.5):

    # Calculate distances and `basic_loss`
    distance_anchor_positive = ...
    distance_anchor_negative = ...
    basic_loss = distance_anchor_positive - distance_anchor_negative

    # Count the number of triplets where `basic_loss` > 0 and number of total triplets
    number_error_triplets = ...
    number_total_triplets = ...

    # Calculate the error
    error = number_error_triplets / number_total_triplets

    return error

In the main code I would then use this function as a metric for the training process:

model = create_cov_net()

model.compile(
    optimizer=...,
    loss=triplet_loss_keras,
    metrics=[triplet_error_rate]
)

model.fit(
    ...
)

Any help is greatly appreciated.


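EDIT: Just use the fraction_positive_triplets value returned by the batch_all_triplet_loss function. It does what I referred to as a triplet error rate. Note that it counts the valid triplets for which d(a,p) - d(a,n) + margin is still positive, so the margin is included.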

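import tensorflow as tf  # batch_all_triplet_loss comes from this repository's triplet loss module

def triplet_error(y_true, y_pred, margin=0.5):
    # Squeeze the labels to shape (batch_size,) and keep only the second return
    # value, fraction_positive_triplets.
    _, error = batch_all_triplet_loss(tf.squeeze(y_true), y_pred, margin)

    return error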

In the main code you can then do:

siamese_net_model.compile(
    ...,
    metrics=[triplet_error]
)
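
For anyone who wants to sanity-check the number this metric reports, here is a small NumPy sketch of the same quantity computed on a toy batch. It is not the repository's code: the function name, the toy data and the use of a plain `> 0` test (the repository's function thresholds at a small epsilon instead) are my own assumptions, so treat it as a reference for intuition rather than a drop-in replacement.

```python
import numpy as np


def fraction_positive_triplets(labels, embeddings, margin=0.5, squared=False):
    """Fraction of valid triplets (a, p, n) with d(a, p) - d(a, n) + margin > 0.

    Hypothetical reference helper, not part of the repository.
    """
    labels = np.asarray(labels).reshape(-1)
    embeddings = np.asarray(embeddings, dtype=np.float64)

    # Pairwise Euclidean distances, shape (batch, batch).
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = np.sum(diff ** 2, axis=-1)
    if not squared:
        dist = np.sqrt(dist)

    # loss[a, p, n] = d(a, p) - d(a, n) + margin
    loss = dist[:, :, None] - dist[:, None, :] + margin

    # A triplet is valid if a and p are different samples with the same label
    # and n has a different label than a.
    same = labels[:, None] == labels[None, :]
    distinct = ~np.eye(len(labels), dtype=bool)
    valid = (same & distinct)[:, :, None] & (~same)[:, None, :]

    num_valid = valid.sum()
    if num_valid == 0:
        return 0.0  # no valid triplets in this batch
    num_positive = np.sum(valid & (loss > 0))
    return float(num_positive / num_valid)


if __name__ == "__main__":
    # Toy batch: two samples of class 0 and one sample of class 1 on a line.
    toy_labels = [0, 0, 1]
    toy_embeddings = [[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]]
    print(fraction_positive_triplets(toy_labels, toy_embeddings, margin=0.5))  # 0.0
    print(fraction_positive_triplets(toy_labels, toy_embeddings, margin=1.5))  # 0.5
```

In the toy batch there are only two valid triplets, (0, 1, 2) and (1, 0, 2). With margin=0.5 both already satisfy the margin, and with margin=1.5 only the second one violates it, which is why the value moves from 0.0 to 0.5.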