tensorflow / similarity

TensorFlow Similarity is a Python package focused on making similarity learning quick and easy.
Apache License 2.0

convert to tensor in index prevents passing a dict of inputs #322

Closed owenvallis closed 1 year ago

owenvallis commented 1 year ago
ValueError                                Traceback (most recent call last)
<ipython-input-19-d4b57d7b51eb> in <module>()
      1 index_data = [{"prompt_input": p, "response_input": r} for p, r in zip(prompt_index, response_index)]
----> 2 model.index({"prompt_input": prompt_index, "response_input": response_index}, y=y_index, data=index_data)

2 frames
google3/third_party/py/tensorflow_similarity/models/similarity_model.py in index(self, x, y, data, build, verbose)
    350       print("|-Computing embeddings")
    351     with tf.device("/cpu:0"):
--> 352       x = tf.convert_to_tensor(np.array(x))
    353 
    354     predictions = self.predict(x)

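This change (the tf.convert_to_tensor(np.array(x)) call at line 352) was required to prevent a slowdown and a possible memory leak when passing lists of inputs instead of np.array or tensors. However, it breaks passing multiple inputs, such as the dict of named inputs in the call above.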

We should add a type check first and handle the multi-input case properly.
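A minimal sketch of what such a type check could look like, not the library's actual fix. The helper name prepare_inputs and its convert parameter are hypothetical; convert stands in for whatever conversion is wanted, for example lambda a: tf.convert_to_tensor(np.array(a)), and is passed in so the sketch does not depend on TensorFlow.

```python
from collections.abc import Mapping


def prepare_inputs(x, convert):
    """Apply `convert` per input instead of to the whole container.

    Hypothetical helper, not part of TensorFlow Similarity.
    `convert` is an assumed callable, e.g.
    `lambda a: tf.convert_to_tensor(np.array(a))`.
    """
    if isinstance(x, Mapping):
        # Multi-input model: keep the input names and convert each value
        # separately, so the dict itself is never passed to the converter.
        return {name: convert(value) for name, value in x.items()}
    # Anything else is treated as a single batch (list, array, or tensor).
    return convert(x)
```

With this shape, a call such as prepare_inputs({"prompt_input": prompt_index, "response_input": response_index}, convert) would return a dict with the same keys. Treating every non-dict value as one batch is an assumption of the sketch: a plain list is ambiguous between "a list of examples" and "a list of model inputs", and a real implementation would have to decide how to tell them apart.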

owenvallis commented 1 year ago

Removing all tf.convert_to_tensor() calls before predict. While the previous change prevented the memory leak when multiple models were called in a loop, it also restricted calls to predict to a single tensor batch. This is too restrictive and prevents us from calling multi-headed models.