tensorflow / similarity

TensorFlow Similarity is a Python package focused on making similarity learning quick and easy.
Apache License 2.0

Does SplitValidationLoss assume presence of unknown class by design? #215

Open · kechan opened 2 years ago

kechan commented 2 years ago

It appears that this callback:

import tensorflow as tf
from tensorflow_similarity.callbacks import SplitValidationLoss

val_loss = SplitValidationLoss(queries_x, queries_y, targets_x, targets_y,
                               metrics=['binary_accuracy'],
                               known_classes=tf.constant(range(len(classes))),
                               k=k,
                               # tb_logdir=log_dir,  # uncomment if you want to track in TensorBoard
                               )

will generate an "invalid length" kind of error. I browsed the code a little and figured that known_classes must be a strict subset of all the classes, so that there is at least one unknown class; it doesn't handle the case of zero unknown classes. I wonder if this is by design, since the name has "Split" in it.
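To see why marking every class as known is a problem, here is a minimal sketch of a known/unknown partition of the query set. It assumes the callback splits queries by checking each query label against known_classes. The helper split_by_known is hypothetical and is not tensorflow_similarity's actual code. When known_classes covers every label, the unknown partition has zero examples. A lookup on that empty side would then call predict on an empty batch, which is consistent with the "Empty batch_outputs" error below.

```python
import numpy as np


def split_by_known(query_labels, known_classes):
    """Hypothetical helper: return (known_idx, unknown_idx) index arrays.

    Assumes the known/unknown split is a simple membership test of each
    query label against known_classes (not the library's own code).
    """
    labels = np.asarray(query_labels)
    is_known = np.isin(labels, np.asarray(known_classes))
    known_idx = np.flatnonzero(is_known)     # positions of queries with a known class
    unknown_idx = np.flatnonzero(~is_known)  # positions of queries with an unknown class
    return known_idx, unknown_idx


# Every class is marked as known, as in the snippet above (5 hypothetical classes).
classes = ["a", "b", "c", "d", "e"]
query_labels = [0, 1, 2, 3, 4, 0, 1]
known_idx, unknown_idx = split_by_known(query_labels, range(len(classes)))
print(len(known_idx), len(unknown_idx))  # 7 0
```

With, say, known_classes=[0, 1], the same call yields four known and three unknown queries, so both sides of the split have something to evaluate.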

This is what the error looks like:

/usr/local/lib/python3.7/dist-packages/tensorflow_similarity/models/similarity_model.py in lookup(self, x, k, verbose)
    410             List[List[Lookup]]
    411         """
--> 412         predictions = self.predict(x)
    413         return self._index.batch_lookup(
    414             predictions=predictions, k=k, verbose=verbose

ValueError: Unexpected result of `predict_function` (Empty batch_outputs). Please use `Model.compile(..., run_eagerly=True)`, or `tf.config.run_functions_eagerly(True)` for more information of where went wrong, or file a issue/bug to `tf.keras`.

owenvallis commented 2 years ago

That's correct. I added SplitValidationLoss to monitor matching-classification performance separately on the known and the unknown classes, to get a sense of how well an embedding is generalizing during training. If you don't have any unknown classes, I would recommend using EvalCallback instead. The two callbacks are very similar; so much so that I'm thinking of eventually combining them into a more general callback that also supports the retrieval metrics.
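Following that recommendation, one way to avoid the error is to choose the callback based on whether any query class is actually unknown. This is a sketch only. make_validation_callback is a hypothetical helper. The two constructors are passed in as arguments so the selection logic can be tested without TensorFlow. It also assumes that EvalCallback accepts the same four positional arguments (queries, query labels, targets, target labels) as SplitValidationLoss.

```python
from typing import Any, Callable, Sequence

import numpy as np


def make_validation_callback(
    queries: Any,
    query_labels: Sequence[int],
    targets: Any,
    target_labels: Sequence[int],
    known_classes: Sequence[int],
    eval_callback: Callable[..., Any],
    split_callback: Callable[..., Any],
    **kwargs: Any,
) -> Any:
    """Hypothetical helper: use the split callback only if some query class is unknown.

    eval_callback / split_callback are the constructors to call, e.g.
    EvalCallback and SplitValidationLoss from tensorflow_similarity.callbacks.
    Extra keyword arguments (metrics=..., k=...) are forwarded unchanged.
    """
    labels = np.asarray(query_labels)
    has_unknown = not np.isin(labels, np.asarray(known_classes)).all()
    if has_unknown:
        # At least one query label is outside known_classes, so the
        # known/unknown split has something on the unknown side.
        return split_callback(
            queries, query_labels, targets, target_labels,
            known_classes=known_classes, **kwargs,
        )
    # Every query class is known: nothing to split, so fall back to the
    # plain evaluation callback as recommended in the thread.
    return eval_callback(queries, query_labels, targets, target_labels, **kwargs)
```

In a training script you would pass EvalCallback and SplitValidationLoss from tensorflow_similarity.callbacks as the two constructors, along with the metrics= and k= arguments from the original snippet.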

Can you share the error message you were getting? I can likely add a length check and/or a more informative error message for that case.
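A length check like the one mentioned above could look roughly like the following sketch. validate_split and its messages are hypothetical and are not part of tensorflow_similarity. It assumes the same membership-based split of query labels against known_classes. It checks both sides, on the assumption that an empty known split would fail in the same way as an empty unknown split.

```python
import numpy as np


def validate_split(query_labels, known_classes):
    """Hypothetical pre-check: fail early with a clear message if a split is empty.

    Returns (n_known, n_unknown) when both sides of the split are non-empty.
    """
    labels = np.asarray(query_labels)
    is_known = np.isin(labels, np.asarray(known_classes))
    n_known = int(is_known.sum())
    n_unknown = int(labels.size - n_known)
    if n_known == 0:
        raise ValueError(
            "SplitValidationLoss: none of the query labels are in known_classes, "
            "so the known split is empty."
        )
    if n_unknown == 0:
        raise ValueError(
            "SplitValidationLoss: every query label is in known_classes, so the "
            "unknown split is empty. Use EvalCallback when there are no unknown classes."
        )
    return n_known, n_unknown
```

Running this before any lookup replaces the generic "Empty batch_outputs" error from Keras with a message that points directly at known_classes.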

kechan commented 2 years ago

@owenvallis I pasted it into my original post (so it's easier to find if someone searches for it).