tensorflow / recommenders

TensorFlow Recommenders is a library for building recommender system models using TensorFlow.
Apache License 2.0
1.85k stars 278 forks source link

Unable to use query_with_exclusions with undefined length of exclusions #517

Open TPME opened 2 years ago

TPME commented 2 years ago

Hi, I am trying to save a model that uses query_with_exclusions with a variable number of exclusions (because each user may have seen a different number of posts):

recommendations_k_21 = tf.function(lambda user_id: final_model(queries=user_id, k=21))
recommendations_k_21_exclusions = tf.function(lambda user_id, exclusions: final_model.query_with_exclusions(queries=user_id, k=21, exclusions=exclusions))

# Call them to create concrete functions.
_, ids = recommendations_k_21(tf.constant("b6bfd5af-d21e-4726-860c-dc663c34e20f"))
_, ids = recommendations_k_21_exclusions(tf.constant(["b6bfd5af-d21e-4726-860c-dc663c34e20f"]), tf.constant([["1234"]]))

signatures = {"recommendations_k_21": recommendations_k_21.get_concrete_function(user_id=tf.TensorSpec(shape=(), dtype=tf.string)), "recommendations_k_21_exclusions": recommendations_k_21_exclusions.get_concrete_function(user_id=tf.TensorSpec(shape=(1,), dtype=tf.string), exclusions=tf.TensorSpec(shape=(1, None), dtype=tf.string))}

tf.saved_model.save(final_model, MODEL_STORAGE_PATH, signatures)

However, I get the following error:

WARNING:tensorflow:Model was constructed with shape (None,) for input KerasTensor(type_spec=TensorSpec(shape=(None,), dtype=tf.string, name='string_lookup_input'), name='string_lookup_input', description="created by layer 'string_lookup_input'"), but it was called on an input with incompatible shape ().

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [275], in <cell line: 8>()
      5 _, ids = recommendations_k_21(tf.constant("b6bfd5af-d21e-4726-860c-dc663c34e20f"))
      6 _, ids = recommendations_k_21_exclusions(tf.constant(["b6bfd5af-d21e-4726-860c-dc663c34e20f"]), tf.constant([["1234"]]))
----> 8 signatures = {"recommendations_k_21": recommendations_k_21.get_concrete_function(user_id=tf.TensorSpec(shape=(), dtype=tf.string)), "recommendations_k_21_exclusions": recommendations_k_21_exclusions.get_concrete_function(user_id=tf.TensorSpec(shape=(1,), dtype=tf.string), exclusions=tf.TensorSpec(shape=(1, None), dtype=tf.string))}
     10 tf.saved_model.save(final_model, MODEL_STORAGE_PATH, signatures)

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py:1239, in Function.get_concrete_function(self, *args, **kwargs)
   1237 def get_concrete_function(self, *args, **kwargs):
   1238   # Implements GenericFunction.get_concrete_function.
-> 1239   concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
   1240   concrete._garbage_collector.release()  # pylint: disable=protected-access
   1241   return concrete

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py:1230, in Function._get_concrete_function_garbage_collected(self, *args, **kwargs)
   1225   return self._stateless_fn._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
   1226       *args, **kwargs)
   1227 elif self._stateful_fn is not None:
   1228   # In this case we have not created variables on the first call. So we can
   1229   # run the first trace but we should fail if variables are created.
-> 1230   concrete = self._stateful_fn._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
   1231       *args, **kwargs)
   1232   if self._created_variables:
   1233     raise ValueError("Creating variables on a non-first call to a function"
   1234                      " decorated with tf.function.")

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/function.py:2533, in Function._get_concrete_function_garbage_collected(self, *args, **kwargs)
   2531   args, kwargs = None, None
   2532 with self._lock:
-> 2533   graph_function, _ = self._maybe_define_function(args, kwargs)
   2534   seen_names = set()
   2535   captured = object_identity.ObjectIdentitySet(
   2536       graph_function.graph.internal_captures)

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/function.py:2711, in Function._maybe_define_function(self, args, kwargs)
   2708   cache_key = self._function_cache.generalize(cache_key)
   2709   (args, kwargs) = cache_key._placeholder_value()  # pylint: disable=protected-access
-> 2711 graph_function = self._create_graph_function(args, kwargs)
   2712 self._function_cache.add(cache_key, cache_key_deletion_observer,
   2713                          graph_function)
   2715 return graph_function, filtered_flat_args

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/function.py:2627, in Function._create_graph_function(self, args, kwargs)
   2622 missing_arg_names = [
   2623     "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names)
   2624 ]
   2625 arg_names = base_arg_names + missing_arg_names
   2626 graph_function = ConcreteFunction(
-> 2627     func_graph_module.func_graph_from_py_func(
   2628         self._name,
   2629         self._python_function,
   2630         args,
   2631         kwargs,
   2632         self.input_signature,
   2633         autograph=self._autograph,
   2634         autograph_options=self._autograph_options,
   2635         arg_names=arg_names,
   2636         capture_by_value=self._capture_by_value),
   2637     self._function_attributes,
   2638     spec=self.function_spec,
   2639     # Tell the ConcreteFunction to clean up its graph once it goes out of
   2640     # scope. This is not the default behavior since it gets used in some
   2641     # places (like Keras) where the FuncGraph lives longer than the
   2642     # ConcreteFunction.
   2643     shared_func_graph=False)
   2644 return graph_function

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py:1141, in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, acd_record_initial_resource_uses)
   1138 else:
   1139   _, original_func = tf_decorator.unwrap(python_func)
-> 1141 func_outputs = python_func(*func_args, **func_kwargs)
   1143 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
   1144 # TensorArrays and `None`s.
   1145 func_outputs = nest.map_structure(
   1146     convert, func_outputs, expand_composites=True)

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py:677, in Function._defun_with_scope.<locals>.wrapped_fn(*args, **kwds)
    673 with default_graph._variable_creator_scope(scope, priority=50):  # pylint: disable=protected-access
    674   # __wrapped__ allows AutoGraph to swap in a converted function. We give
    675   # the function a weak reference to itself to avoid a reference cycle.
    676   with OptionalXlaContext(compile_with_xla):
--> 677     out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    678   return out

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py:1127, in func_graph_from_py_func.<locals>.autograph_handler(*args, **kwargs)
   1125 except Exception as e:  # pylint:disable=broad-except
   1126   if hasattr(e, "ag_error_metadata"):
-> 1127     raise e.ag_error_metadata.to_exception(e)
   1128   else:
   1129     raise

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py:1116, in func_graph_from_py_func.<locals>.autograph_handler(*args, **kwargs)
   1114 # TODO(mdan): Push this block higher in tf.function's call stack.
   1115 try:
-> 1116   return autograph.converted_call(
   1117       original_func,
   1118       args,
   1119       kwargs,
   1120       options=autograph.ConversionOptions(
   1121           recursive=True,
   1122           optional_features=autograph_options,
   1123           user_requested=True,
   1124       ))
   1125 except Exception as e:  # pylint:disable=broad-except
   1126   if hasattr(e, "ag_error_metadata"):

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py:439, in converted_call(f, args, kwargs, caller_fn_scope, options)
    437 try:
    438   if kwargs is not None:
--> 439     result = converted_f(*effective_args, **kwargs)
    440   else:
    441     result = converted_f(*effective_args)

File /tmp/__autograph_generated_fileh578f4h6.py:5, in outer_factory.<locals>.inner_factory.<locals>.<lambda>(user_id, exclusions)
      4 def inner_factory(ag__):
----> 5     tf__lam = lambda user_id, exclusions: ag__.with_function_scope(lambda lscope: ag__.converted_call(final_model.query_with_exclusions, (), dict(queries=user_id, k=21, exclusions=exclusions), lscope), 'lscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True))
      6     return tf__lam

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/autograph/core/function_wrappers.py:113, in with_function_scope(thunk, scope_name, options)
    111 """Inline version of the FunctionScope context manager."""
    112 with FunctionScope('lambda_', scope_name, options) as scope:
--> 113   return thunk(scope)

File /tmp/__autograph_generated_fileh578f4h6.py:5, in outer_factory.<locals>.inner_factory.<locals>.<lambda>(lscope)
      4 def inner_factory(ag__):
----> 5     tf__lam = lambda user_id, exclusions: ag__.with_function_scope(lambda lscope: ag__.converted_call(final_model.query_with_exclusions, (), dict(queries=user_id, k=21, exclusions=exclusions), lscope), 'lscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True))
      6     return tf__lam

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py:331, in converted_call(f, args, kwargs, caller_fn_scope, options)
    329 if conversion.is_in_allowlist_cache(f, options):
    330   logging.log(2, 'Allowlisted %s: from cache', f)
--> 331   return _call_unconverted(f, args, kwargs, options, False)
    333 if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
    334   logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f)

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py:458, in _call_unconverted(f, args, kwargs, options, update_cache)
    455   return f.__self__.call(args, kwargs)
    457 if kwargs is not None:
--> 458   return f(*args, **kwargs)
    459 return f(*args)

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153   raise e.with_traceback(filtered_tb) from None
    154 finally:
    155   del filtered_tb

File /tmp/__autograph_generated_filef8r31nzk.py:34, in outer_factory.<locals>.inner_factory.<locals>.tf__query_with_exclusions(self, queries, exclusions, k)
     32 retval_ = ag__.UndefinedReturnValue()
     33 k = ag__.if_exp(ag__.ld(k) is not None, lambda : ag__.ld(k), lambda : ag__.ld(self)._k, 'k is not None')
---> 34 adjusted_k = ag__.ld(k) + ag__.ld(exclusions).shape[1]
     35 (x, y) = ag__.converted_call(ag__.ld(self), (), dict(queries=ag__.ld(queries), k=ag__.ld(adjusted_k)), fscope)
     36 try:

TypeError: in user code:

    File "/tmp/ipykernel_162/1740288157.py", line 2, in None  *
        lambda user_id, exclusions: final_model.query_with_exclusions(queries=user_id, k=21, exclusions=exclusions)
    File "/opt/conda/lib/python3.10/site-packages/tensorflow_recommenders/layers/factorized_top_k.py", line 287, in query_with_exclusions  *
        adjusted_k = k + exclusions.shape[1]

    TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'

What is my mistake here? I got the idea from https://github.com/tensorflow/recommenders/issues/298#issuecomment-918412784.

kylemcmearty commented 2 years ago

I'm wondering whether your problem is that you did not define anything in the shape: signatures = {"recommendations_k_21": recommendations_k_21.get_concrete_function(user_id=tf.TensorSpec(shape=(), dtype=tf.string)),