tensorflow / similarity

TensorFlow Similarity is a Python package focused on making similarity learning quick and easy.
Apache License 2.0

Casting error in EvalCallback for tf.string input queries/targets #263

Closed. MarinaZhang closed this issue 2 years ago.

MarinaZhang commented 2 years ago

When using the EvalCallback on string inputs, i.e.

EvalCallback(queries, query_labels, targets, target_labels)

where queries and targets are tf.string Tensors, I get the following error:

    steps_per_epoch=steps, callbacks=callbacks, validation_steps=10)
  File "/home/jupyter/envs/tf2.7/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/jupyter/envs/tf2.7/lib/python3.7/site-packages/tensorflow_similarity/callbacks.py", line 210, in on_epoch_end
    distance_thresholds=self.distance_thresholds,
  File "/home/jupyter/envs/tf2.7/lib/python3.7/site-packages/tensorflow_similarity/callbacks.py", line 359, in _compute_classification_metrics
    lookup_distances = unpack_lookup_distances(lookups, queries.dtype)
  File "/home/jupyter/envs/tf2.7/lib/python3.7/site-packages/tensorflow_similarity/utils.py", line 74, in unpack_lookup_distances
    ragged_dists = tf.ragged.constant(all_values, dtype=base_type)
TypeError: Cannot convert [0.003274083137512207,  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.006584286689758301, ...] to EagerTensor of dtype string
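The failure can be reproduced without EvalCallback. The following is a minimal sketch, assuming TensorFlow 2.x running eagerly. It calls tf.ragged.constant the same way the traceback does, first with the query dtype (tf.string) and then with a float dtype. The distance values are copied from the error message above.

```python
import tensorflow as tf

# Per-query lists of float distances, shaped like the values in the traceback.
distances = [[0.003274083137512207, 0.0], [0.006584286689758301]]

# Casting the distances to the *query* dtype (tf.string) fails with TypeError,
# matching the error raised inside unpack_lookup_distances.
try:
    tf.ragged.constant(distances, dtype=tf.string)
    failed = False
except TypeError as err:
    failed = True
    print("TypeError:", err)

# A float dtype for the distances works no matter what dtype the queries have.
ok = tf.ragged.constant(distances, dtype=tf.float32)
print(ok.shape)
```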

This happens because _compute_classification_metrics (https://github.com/tensorflow/similarity/blob/c22e1e7a9b9f8a5549a84019930fb4efd9b20e24/tensorflow_similarity/callbacks.py#L359) passes queries.dtype to unpack_lookup_distances (https://github.com/tensorflow/similarity/blob/c22e1e7a9b9f8a5549a84019930fb4efd9b20e24/tensorflow_similarity/utils.py#L74), which then tries to cast the float distances to the dtype of the queries. When the queries are tf.string, this float-to-string cast fails. Thanks!
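One possible fix is to pick the distance dtype independently of the queries. The following sketch makes several assumptions. unpack_distances and FakeLookup are hypothetical names, not the library's API. The sketch only assumes that each lookup exposes a float distance attribute. Padding short rows with +inf is a choice made for this example so that a missing neighbour never looks close.

```python
import math
from dataclasses import dataclass
from typing import List, Sequence

import tensorflow as tf


@dataclass
class FakeLookup:
    """Hypothetical stand-in for a lookup result; only the distance is used."""
    distance: float


def unpack_distances(
    lookups: Sequence[Sequence[FakeLookup]], dtype: tf.DType = tf.float32
) -> tf.Tensor:
    """Pack per-query lookup distances into a dense float tensor.

    The dtype is a float type for the distances themselves and is never
    derived from the queries, which may be tf.string.
    """
    rows: List[List[float]] = [[lu.distance for lu in row] for row in lookups]
    ragged = tf.ragged.constant(rows, dtype=dtype)
    # Pad rows with fewer neighbours using +inf (an assumption of this sketch).
    return ragged.to_tensor(default_value=math.inf)


# String queries, as in the issue; their dtype is deliberately not passed on.
queries = tf.constant(["cat", "dog"])
lookups = [[FakeLookup(0.0033), FakeLookup(0.0)], [FakeLookup(0.0066)]]
dists = unpack_distances(lookups)
print(dists)
```

The only change relative to the failing call site is that the caller no longer forwards queries.dtype, so string, image, or any other query type works.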