tensorflow / serving

A flexible, high-performance serving system for machine learning models
https://www.tensorflow.org/serving
Apache License 2.0

[ TF 2.0 ] tf.function throws error when trying to export model #1698

Closed ayushch3 closed 4 years ago

ayushch3 commented 4 years ago

System information

Describe the current behavior

I ported the n-gram computation for truecasing to TF 2.0. It works fine when I run the function standalone, but it fails when I wrap the function with tf.function. There is a get_score function that creates unigram/bigram combinations. When the function runs standalone with a string tensor as input, everything works as expected.

Describe the expected behavior

The tf.function wrapper is needed so that the model can be exported to run with TF Serving. The code works correctly when it is given an input directly, but it fails with the error shown below once the tf.function wrapper is added.
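For context, here is a minimal sketch of the export path that makes the tf.function wrapper necessary. It is not the truecaser from this issue: the `Lowercaser` class, its `serve` method, and the `output_text` key are hypothetical stand-ins, and the export directory is a temporary one. Only the `input_text` name is taken from the signature in the code below.

```python
import os
import tempfile

import tensorflow as tf


class Lowercaser(tf.Module):
    """Hypothetical stand-in for the truecaser: lowercases each token."""

    @tf.function(input_signature=[
        tf.TensorSpec(shape=(None,), dtype=tf.string, name="input_text")])
    def serve(self, tokens):
        # Returning a dict gives the output a stable name in the signature.
        return {"output_text": tf.strings.lower(tokens)}


model = Lowercaser()

# Save into a numbered version subdirectory of a temporary base path.
export_dir = os.path.join(tempfile.mkdtemp(), "1")
tf.saved_model.save(model, export_dir,
                    signatures={"serving_default": model.serve})

# Reload and call the exported signature the way a serving client would.
reloaded = tf.saved_model.load(export_dir)
serving_fn = reloaded.signatures["serving_default"]
output = serving_fn(input_text=tf.constant(["Hello", "WORLD"]))
```

TF Serving looks for each model version in a numbered subdirectory under the model base path, which is why the sketch saves into a directory named `1`. The function has to be traceable as a graph for `tf.saved_model.save` to record it as a signature.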

Standalone code to reproduce the issue

    def get_score(self, prev_token, possible_token, next_token):
        possible_token_l = tf.strings.lower(possible_token)
        alternative_tokens = self.get_alternative_tokens(possible_token_l)

        unigram_score = self.compute_unigram_score(possible_token, alternative_tokens)
        result = tf.math.log(unigram_score)

        if prev_token is not None:
            bigram_backward_score = self.compute_bigram_backward_score(possible_token, prev_token, alternative_tokens)
            result += tf.math.log(bigram_backward_score)

        if next_token is not None:
            bigram_forward_score = self.compute_bigram_forward_score(possible_token, next_token, alternative_tokens)
            result += tf.math.log(bigram_forward_score)

        if prev_token is not None and next_token is not None:
            trigram_score = self.compute_trigram_score(possible_token, prev_token, next_token, alternative_tokens)
            result += tf.math.log(trigram_score)
        return result

    @tf.function(input_signature=[tf.TensorSpec(shape=(None), dtype=tf.string, name="input_text")])
    def get_true_case(self, tokens_tensor):
        cap_first_token = tf.reshape(tokens_tensor[0], [1])
        trueCasedTokens = cap_first_token
        tokens_tensor = tf.slice(tokens_tensor, [1], [-1])

        condition = lambda tokens_tensor, trueCasedTokens: tf.greater(tf.size(tokens_tensor), 0)

        def body(tokens_tensor, trueCasedTokens):
            cur_tokens_tensor = tf.concat([trueCasedTokens[-1:], tokens_tensor], 0)
            curToken = tf.get_static_value(tf.slice(cur_tokens_tensor, [1], [1]))[0]
            curToken = tf.get_static_value(tf.strings.lower(curToken))

            prevToken = tf.slice(cur_tokens_tensor, [0], [1])
            if tf.get_static_value(tf.greater(tf.size(cur_tokens_tensor), [3]))[0]:
                nextToken = tf.slice(cur_tokens_tensor, [2], [1])
            else:
                nextToken = None

            wordCasingLookup = tf.reshape(self.get_alternative_tokens(tf.constant(curToken)), [-1])
            if tf.get_static_value(tf.equal(tf.size(wordCasingLookup), [0]))[0]:
                trueCasedTokens = tf.concat([trueCasedTokens, tf.constant(curToken)], 0)
            if tf.get_static_value(tf.equal(tf.size(wordCasingLookup), [1]))[0]:
                trueCasedTokens = tf.concat([trueCasedTokens, wordCasingLookup], 0)
            else:
                scores = tf.map_fn(lambda x: self.get_score(prevToken, x, nextToken),
                                   wordCasingLookup, dtype=tf.float32)
                maxElementIndx = tf.get_static_value(tf.argmax(scores))[0]
                trueVariant = tf.slice(wordCasingLookup, [maxElementIndx], [1])
                trueCasedTokens = tf.concat([trueCasedTokens, trueVariant], 0)

            tokens_tensor = tf.slice(tokens_tensor, [1], [-1])
            return tokens_tensor, trueCasedTokens

        res = tf.while_loop(condition,
                            lambda tokens_tensor, trueCasedTokens: body(tokens_tensor, trueCasedTokens),
                            [tokens_tensor, trueCasedTokens])
        res = res[1]
        res = tf.gather(res, tf.where(res != b''))

        return res

Other info / logs

  File "/Users/ayushc/transcript-post-processor/truecaser.py", line 67, in truecase_tokenize
    res = tf_model.get_true_case(tokens_tensor)
  File "/usr/local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 457, in __call__
    result = self._call(*args, **kwds)
  File "/usr/local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 503, in _call
    self._initialize(args, kwds, add_initializers_to=initializer_map)
  File "/usr/local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 408, in _initialize
    *args, **kwds))
  File "/usr/local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 1848, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/usr/local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2150, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/usr/local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2041, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/usr/local/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py", line 915, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/usr/local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 358, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/usr/local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2658, in bound_method_wrapper
    return wrapped_fn(*args, **kwargs)
  File "/usr/local/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py", line 905, in wrapper
    raise e.ag_error_metadata.to_exception(e)
TypeError: in converted code:

    ./truecaser/truecaser_tf.py:160 body  *
        curToken = tf.get_static_value(tf.slice(cur_tokens_tensor, [1], [1]))[0]
    /usr/local/lib/python3.7/site-packages/tensorflow_core/python/ops/control_flow_ops.py:2478 while_loop_v2
        return_same_structure=True)
    /var/folders/4j/sdgl7s9j3w51c50b5wy4v_p00000gn/T/tmpw4oa0q3i.py:21 body
        curToken = ag__.converted_call(tf.get_static_value, body_scope.callopts, (ag__.converted_call(tf.slice, body_scope.callopts, (cur_tokens_tensor, [1], [1]), None, body_scope),), None, body_scope)[0]

    TypeError: 'NoneType' object is not subscriptable

The entire source code is available here: https://github.com/ayushch3/truecaser_tf

The weights file can be accessed here: https://drive.google.com/file/d/1DpmsDYm-gzcwXCJT4sMCUXmtHGSxKiIR/view?usp=sharing

rmothukuru commented 4 years ago

This question is better asked on StackOverflow since it is not a bug or feature request. There is also a larger community that reads questions there.

ayushch3 commented 4 years ago

@rmothukuru It's actually a bug: the code works as is when input is passed to the function directly, but adding the tf.function wrapper causes it to fail. My guess is that it has something to do with running this in graph mode. I need to be able to export that function for TF Serving, and this remains a blocker.
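The graph-mode guess is consistent with the traceback. `tf.get_static_value` returns `None` when a tensor's value cannot be determined at graph-construction time, and indexing that `None` with `[0]` would raise the `'NoneType' object is not subscriptable` error. The sketch below is an illustration rather than the code from the repository: the `probe` function and the sample tokens are made up, and it only compares what `tf.get_static_value` returns eagerly and during tracing.

```python
import tensorflow as tf

# Eager mode: the tensor holds a concrete value, so a static value exists.
eager_tokens = tf.constant(["the", "quick", "fox"])
eager_value = tf.get_static_value(tf.slice(eager_tokens, [1], [1]))

# Collects what tf.get_static_value returns while the function is traced.
traced_values = []


@tf.function(input_signature=[
    tf.TensorSpec(shape=(None,), dtype=tf.string)])
def probe(tokens):
    # During tracing, `tokens` is a symbolic placeholder with no value yet,
    # so the slice cannot be evaluated statically.
    traced_values.append(
        tf.get_static_value(tf.slice(tokens, [1], [1])))
    return tokens


probe(eager_tokens)

print("eager: ", eager_value)
print("traced:", traced_values[0])
```

In eager mode the call returns a NumPy array, so the `[0]` subscript works. During tracing it returns `None`, which matches the line the traceback points at.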

rmothukuru commented 4 years ago

@ayushch3, if the problem is with the tf.function wrapper, then it is more of a TensorFlow issue than a Serving issue. Please raise an issue in the TensorFlow repository. Thanks!

ayushch3 commented 4 years ago

@rmothukuru I was able to resolve the issue with the tf.function wrapper, but the exported model is still not compatible with TF Serving:

https://github.com/tensorflow/serving/issues/1707
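The thread does not record what change resolved the tf.function error. One common approach, sketched here purely as an assumption, is to keep every value as a tensor and express data-dependent branches with `tf.cond` instead of `tf.get_static_value` and a Python `if`. The `casing_table` and its contents are hypothetical and stand in for the n-gram scoring in the original code.

```python
import tensorflow as tf

# Hypothetical casing table standing in for the n-gram model in the issue.
casing_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        tf.constant(["london", "tensorflow"]),
        tf.constant(["London", "TensorFlow"])),
    default_value="")


@tf.function(input_signature=[
    tf.TensorSpec(shape=(None,), dtype=tf.string, name="input_text")])
def true_case(tokens):
    lowered = tf.strings.lower(tokens)
    num_tokens = tf.size(tokens)
    out = tf.TensorArray(tf.string, size=num_tokens)
    # AutoGraph turns a loop over tf.range into a graph-level while loop.
    for i in tf.range(num_tokens):
        cur = lowered[i]
        cased = casing_table.lookup(cur)
        # The branch depends on a tensor value, so it is decided at run
        # time with tf.cond rather than with a Python `if` at trace time.
        chosen = tf.cond(
            tf.equal(tf.strings.length(cased), 0),
            lambda: cur,     # unknown token: keep the lowercased form
            lambda: cased)   # known token: use the stored casing
        out = out.write(i, chosen)
    return out.stack()


result = true_case(tf.constant(["i", "visited", "LONDON"]))
```

Because nothing in the function needs a concrete value while it is being traced, the same code runs both when called directly and when wrapped for export.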

rmothukuru commented 4 years ago

@ayushch3, can you please confirm whether we can close this issue, since the original problem is resolved and the subsequent error is being tracked in #1707? Thanks!

rmothukuru commented 4 years ago

@ayushch3, can you please respond to the above comment? Thanks!