MarinaZhang closed this issue 2 years ago.
When using the EvalCallback on string inputs, i.e.
EvalCallback(queries, query_labels, targets, target_labels)
where queries and targets are tf.string Tensors, I get the following error:
    steps_per_epoch=steps, callbacks=callbacks, validation_steps=10)
  File "/home/jupyter/envs/tf2.7/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/jupyter/envs/tf2.7/lib/python3.7/site-packages/tensorflow_similarity/callbacks.py", line 210, in on_epoch_end
    distance_thresholds=self.distance_thresholds,
  File "/home/jupyter/envs/tf2.7/lib/python3.7/site-packages/tensorflow_similarity/callbacks.py", line 359, in _compute_classification_metrics
    lookup_distances = unpack_lookup_distances(lookups, queries.dtype)
  File "/home/jupyter/envs/tf2.7/lib/python3.7/site-packages/tensorflow_similarity/utils.py", line 74, in unpack_lookup_distances
    ragged_dists = tf.ragged.constant(all_values, dtype=base_type)
TypeError: Cannot convert [0.003274083137512207, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.006584286689758301, ...] to EagerTensor of dtype string
This happens because _compute_classification_metrics (https://github.com/tensorflow/similarity/blob/c22e1e7a9b9f8a5549a84019930fb4efd9b20e24/tensorflow_similarity/callbacks.py#L359) passes queries.dtype to unpack_lookup_distances, which then uses it as the dtype argument of tf.ragged.constant (https://github.com/tensorflow/similarity/blob/c22e1e7a9b9f8a5549a84019930fb4efd9b20e24/tensorflow_similarity/utils.py#L74). The code therefore tries to convert the float distances to the dtype of the queries (here, tf.string), which fails. Thanks!
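A minimal sketch that reproduces the failure outside the callback, assuming TensorFlow 2.x in eager mode. pack_distances is a hypothetical stand-in for unpack_lookup_distances, not the library's code; it only mirrors the tf.ragged.constant call shown in the traceback. Passing the string dtype of the queries fails, while passing a float dtype that does not depend on queries.dtype works.

```python
import tensorflow as tf


def pack_distances(all_values, dtype=tf.float32):
    """Hypothetical helper: pack per-query lookup distances into a RaggedTensor.

    `dtype` describes the distances (floats), so it should not be taken
    from the query tensor, whose dtype may be tf.string.
    """
    return tf.ragged.constant(all_values, dtype=dtype)


# String queries, as in the issue; the lookup distances are always floats.
queries = tf.constant(["first query", "second query"])
distances = [[0.003274083137512207, 0.0], [0.006584286689758301]]

# Passing queries.dtype (tf.string) reproduces the reported conversion error.
string_dtype_failed = False
try:
    pack_distances(distances, dtype=queries.dtype)
except (TypeError, ValueError) as err:
    string_dtype_failed = True
    print("conversion failed:", err)

# Using a float dtype that is independent of queries.dtype succeeds.
dists = pack_distances(distances, dtype=tf.float32)
print(dists.dtype, dists.row_lengths().numpy().tolist())
```

Applied to the library, the analogous change would be to pass a float dtype to unpack_lookup_distances instead of queries.dtype; this is a suggested direction, not a confirmed upstream fix.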