tensorflow / similarity

TensorFlow Similarity is a python package focused on making similarity learning quick and easy.
Apache License 2.0

What values of metrics can be interpreted as successful during training Similarity Models #327

Closed: sergeichukd closed this issue 1 year ago

sergeichukd commented 1 year ago

Hello! Thank you for your work. TF Similarity is great!

In the hello world notebook, the Training section says:

NOTE: don't expect the validation loss to decrease too much here because we only use a subset of the classes within the train data but include all classes in the validation data.

I'm new to training similarity models. Could you give me some advice on which values of the validation metrics can be interpreted as successful training?

owenvallis commented 1 year ago

Hi @sergeichukd,

We provide a number of classification and information retrieval metrics that can be used during training via the EvalCallback object in tensorflow_similarity.callbacks. In general we tend to use the binary_accuracy metric, as it can be thought of as the within-threshold precision multiplied by the recall over the total number of elements in the database. There are more details in the docstrings under the classification metrics directory.
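For reference, here is a rough sketch of wiring the callback into training. It is a sketch, not a drop-in recipe: queries_x/queries_y, targets_x/targets_y, and train_ds are placeholders for your own splits, and you should double check the argument names against the EvalCallback docstring for your version of the package.

from tensorflow_similarity.callbacks import EvalCallback

# queries_x / queries_y: held-out examples looked up against the index each epoch.
# targets_x / targets_y: examples rebuilt into the index at the end of every epoch.
eval_cb = EvalCallback(
    queries_x,
    queries_y,
    targets_x,
    targets_y,
    metrics=["binary_accuracy", "precision", "recall"],
)

model.fit(train_ds, epochs=20, callbacks=[eval_cb])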

Regarding the comment: the issue there is that we train the model using a subset of the classes but include all of the classes in the validation dataset. This gives us a sense of how well the model will generalize to new, unseen classes, but it also means the validation metrics will lag behind the train metrics.
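To make that concrete, here is a plain NumPy illustration of the split; x_train, y_train, x_test, y_test and the class counts are placeholders, not the notebook's actual variables.

import numpy as np

# Train on a subset of classes...
seen_classes = np.arange(6)                   # e.g. classes 0-5 are seen in training
train_mask = np.isin(y_train, seen_classes)
sub_x, sub_y = x_train[train_mask], y_train[train_mask]

# ...but validate on all classes, including the ones the model never saw.
# Metrics on this split measure generalization to unseen classes, which is
# why they trail the training metrics.
val_x, val_y = x_test, y_test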

Let me know if you have any questions about the EvalCallback object or any of the metrics. Here is a more detailed example of running some of the evaluations after the model has finished training.

# Assumes you have trained a model and that index data != query data...
# Imports assumed for this snippet (check the paths for your tensorflow_similarity version):
from collections import Counter

import numpy as np

from tensorflow_similarity.retrieval_metrics import MapAtK, PrecisionAtK, RecallAtK
from tensorflow_similarity.search import NMSLibSearch

# Add all examples to the index.
# Use an exact (brute force) search so the evaluation isn't affected by
# approximate nearest neighbor errors.
brute_force_search = NMSLibSearch(
    distance="cosine",
    dim=model.output.shape[1],
    method='brute_force',
)

# Create or clear the index
try:
    model.reset_index()  # clear the index
except AttributeError:
    model.create_index(brute_force_search)  # or create it.

model.index(index_x, y=index_y, data=index_human_readable_data)

# Find the distance thresholds that hit each target value of the
# calibration metric on the query set.
calibrate_metrics = model.calibrate(
    query_x,
    y=query_y,
    thresholds_targets={"0.99": 0.99, "0.95": 0.95, "0.90": 0.90, "0.85": 0.85, "0.80": 0.80},
    calibration_metric="binary_accuracy",
)

# Classification metrics computed at the calibrated distance thresholds.
eval_cal = model.evaluate_classification(
    query_x,
    y=query_y,
    extra_metrics=['precision', 'binary_accuracy', 'recall', 'npv', 'fpr'],
)

def make_recall_at_k(k: int) -> RecallAtK:
    return RecallAtK(k=k, average="macro")

def make_precision_at_k(k: int) -> PrecisionAtK:
    return PrecisionAtK(k=k, average="macro")

def make_map_at_r(targets_y: np.ndarray, max_class_count: int) -> MapAtK:
    # MAP@R: k is clipped per query to the size of that query's class (its count in r).
    class_counts = Counter(targets_y)
    max_class_count = min(max(class_counts.values()), max_class_count)
    return MapAtK(
        r=class_counts,
        clip_at_r=True,
        k=max_class_count,
        name="map@R",
    )

def make_r_precision(
    targets_y: np.ndarray, max_class_count: int
) -> PrecisionAtK:
    # R-precision: precision@k with k clipped per query to the size of that query's class.
    class_counts = Counter(targets_y)
    max_class_count = min(max(class_counts.values()), max_class_count)
    return PrecisionAtK(
        r=class_counts,
        clip_at_r=True,
        k=max_class_count,
        name="R_Precision",
    )

recall_at_k = [make_recall_at_k(k) for k in [1, 2, 4, 8, 16, 32]]
precision_at_k = [make_precision_at_k(k) for k in [1, 2, 4, 8, 16, 32]]

# df is assumed to be a pandas DataFrame whose categorical "label" column
# holds the target labels; any integer label array works here.
metrics = [
    make_map_at_r(df["label"].cat.codes.values, 300),
    make_r_precision(df["label"].cat.codes.values, 300),
]
metrics.extend(recall_at_k + precision_at_k)

eval_cal = model.evaluate_retrieval(
    query_x,
    y=query_y,
    retrieval_metrics=metrics,
)

sergeichukd commented 1 year ago

Thank you, @owenvallis! This callback is extremely helpful for me.