tensorflow / transform

Input pipeline framework
Apache License 2.0

Graph error when using TFHub Universal-sentence-encoder model #304

Open · IzakMaraisTAL opened this issue 1 year ago

IzakMaraisTAL commented 1 year ago

I want to apply the universal-sentence-encoder model in a TFX Transform preprocessing_fn.

My preprocessing function works in unit tests, but when I run it in the pipeline it fails a graph-validation step.

The Transform executor calls get_analyze_input_columns() on the preprocessing function, and that call fails.

TFX version: 1.12.0, tensorflow-recommenders: 0.7.3, Python version: 3.7.10

Below is a minimal script that reproduces the error outside of TFX:

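```python
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_transform as tft

# Load the TF Hub SavedModel once, outside the traced function.
embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder-large/5")

@tf.function
def preprocess_use(tensorMap):
    # Embed the string feature "input" with the Universal Sentence Encoder.
    return {"output": embed(tensorMap["input"])}

# TF2-style input spec: a batch of scalar strings.
spec = {"input": tf.TensorSpec(shape=(None,), dtype=tf.string)}
# Same check the TFX Transform executor runs; this call raises the error below.
tft.get_analyze_input_columns(preprocess_use, spec)
```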

This gives the error:


Traceback (most recent call last):
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 179, in wrapper
    return func(self, tensor_or_op)
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 749, in get_dependent_inputs
    self._graph_analyzer.analyze_tensor(component).dependent_sources)
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 426, in analyze_tensor
    tf_utils.deref_tensor_or_op(current), parent_results)
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 359, in _compute_analysis_result
    tensor_or_op, parent_analysis_results))
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 320, in _compute_analysis_results_for_func_attributes
    func_graph_analyzer.analyze_tensor(t) for t in func_graph.outputs
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 320, in <listcomp>
    func_graph_analyzer.analyze_tensor(t) for t in func_graph.outputs
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 426, in analyze_tensor
    tf_utils.deref_tensor_or_op(current), parent_results)
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 359, in _compute_analysis_result
    tensor_or_op, parent_analysis_results))
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 320, in _compute_analysis_results_for_func_attributes
    func_graph_analyzer.analyze_tensor(t) for t in func_graph.outputs
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 320, in <listcomp>
    func_graph_analyzer.analyze_tensor(t) for t in func_graph.outputs
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 426, in analyze_tensor
    tf_utils.deref_tensor_or_op(current), parent_results)
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 359, in _compute_analysis_result
    tensor_or_op, parent_analysis_results))
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 320, in _compute_analysis_results_for_func_attributes
    func_graph_analyzer.analyze_tensor(t) for t in func_graph.outputs
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 320, in <listcomp>
    func_graph_analyzer.analyze_tensor(t) for t in func_graph.outputs
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 426, in analyze_tensor
    tf_utils.deref_tensor_or_op(current), parent_results)
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 359, in _compute_analysis_result
    tensor_or_op, parent_analysis_results))
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 320, in _compute_analysis_results_for_func_attributes
    func_graph_analyzer.analyze_tensor(t) for t in func_graph.outputs
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 320, in <listcomp>
    func_graph_analyzer.analyze_tensor(t) for t in func_graph.outputs
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 417, in analyze_tensor
    parents = self._get_parents(tf_utils.deref_tensor_or_op(current))
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 240, in _get_parents
    raise _UnexpectedTableError(tensor_or_op, func_graph_name)
tensorflow_transform.graph_tools._UnexpectedTableError: An unexpected initializable table was encountered (name: "text_preprocessor_1/hash_table"
op: "HashTableV2"
attr {
  key: "container"
  value {
    s: ""
  }
}
attr {
  key: "key_dtype"
  value {
    type: DT_STRING
  }
}
attr {
  key: "shared_name"
  value {
    s: "hash_table_1ad50cc5-00f6-4158-9996-2ed5369c9f0e_load_0_2_load_1"
  }
}
attr {
  key: "use_node_name_sharing"
  value {
    b: true
  }
}
attr {
  key: "value_dtype"
  value {
    type: DT_INT64
  }
}
experimental_debug_info {
  original_node_names: "text_preprocessor_1/hash_table"
}
)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "tmp/tests/use_tf_function_bug.py", line 15, in <module>
    tft.get_analyze_input_columns(preprocess_use, spec)
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/inspect_preprocessing_fn.py", line 70, in get_analyze_input_columns
    graph, structured_inputs, output_tensors)
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 802, in get_dependent_inputs
    dependent_inputs.update(graph_analyzer.get_dependent_inputs(output_tensor))
  File "/.venv/lib/python3.7/site-packages/tensorflow_transform/graph_tools.py", line 194, in wrapper
    ''.format(tensor_or_op, e.op, e.func_graph_name)) from e
ValueError: The tensor_or_op name: "NoOp"
op: "NoOp"
input: "^StatefulPartitionedCall"
attr {
  key: "_acd_function_control_output"
  value {
    b: true
  }
}
 depended on an initializable table (name: "text_preprocessor_1/hash_table"
op: "HashTableV2"
attr {
  key: "container"
  value {
    s: ""
  }
}
attr {
  key: "key_dtype"
  value {
    type: DT_STRING
  }
}
attr {
  key: "shared_name"
  value {
    s: "hash_table_1ad50cc5-00f6-4158-9996-2ed5369c9f0e_load_0_2_load_1"
  }
}
attr {
  key: "use_node_name_sharing"
  value {
    b: true
  }
}
attr {
  key: "value_dtype"
  value {
    type: DT_INT64
  }
}
experimental_debug_info {
  original_node_names: "text_preprocessor_1/hash_table"
}
) that is part of a tf.function graph (pruned), this is not supported. This may be a result of initializing a table in a tf.function
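The repeated frames in this traceback show the analyzer descending through one nested function graph after another (`_compute_analysis_results_for_func_attributes` calling `analyze_tensor`) until `_get_parents` raises `_UnexpectedTableError`. The toy model below is a hypothetical, pure-Python simplification of that walk, not tf.Transform's actual `graph_tools` code; `Node`, `Graph`, `dependent_sources` and the table op set are invented for illustration. It only shows why the same `HashTableV2` op is accepted in the top-level graph but rejected once it sits inside a function graph, which is what the error message reports as "not supported". The outer `NoOp` reaching the call through a single edge loosely mirrors the `NoOp` with input `^StatefulPartitionedCall` in the error.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set

# Op types this toy model treats as initializable tables (assumption).
TABLE_OPS = {"HashTableV2"}


@dataclass
class Node:
    name: str
    op: str
    inputs: List[str] = field(default_factory=list)
    # Nested function graph invoked by this node, e.g. a StatefulPartitionedCall.
    function: Optional["Graph"] = None


@dataclass
class Graph:
    name: str
    nodes: Dict[str, Node]
    outputs: List[str]
    is_function: bool = False


class UnexpectedTableError(Exception):
    """Raised when a table op is reached inside a nested function graph."""


def dependent_sources(graph: Graph, node_name: str) -> Set[str]:
    """Return the Placeholder names that `node_name` transitively depends on."""
    node = graph.nodes[node_name]
    if node.op in TABLE_OPS and graph.is_function:
        # The unsupported case: a table created inside a function graph.
        raise UnexpectedTableError(f"{node.name} in function graph {graph.name}")
    if node.op == "Placeholder":
        return {node.name}
    sources: Set[str] = set()
    for parent in node.inputs:
        sources |= dependent_sources(graph, parent)
    if node.function is not None:
        # Walk the nested graph's outputs, like the repeated
        # _compute_analysis_results_for_func_attributes frames. Simplification:
        # only the table check propagates; nested sources are not mapped back.
        for out in node.function.outputs:
            dependent_sources(node.function, out)
    return sources


# Table in the top-level graph: the walk succeeds.
flat = Graph("main", {
    "input": Node("input", "Placeholder"),
    "table": Node("table", "HashTableV2"),
    "lookup": Node("lookup", "LookupTableFindV2", ["table", "input"]),
}, outputs=["lookup"])
flat_sources = dependent_sources(flat, "lookup")

# Same lookup hidden inside a function graph called from the top level.
inner = Graph("loaded_fn", {
    "table": Node("table", "HashTableV2"),
    "keys": Node("keys", "Placeholder"),
    "lookup": Node("lookup", "LookupTableFindV2", ["table", "keys"]),
}, outputs=["lookup"], is_function=True)
outer = Graph("main", {
    "input": Node("input", "Placeholder"),
    "call": Node("call", "StatefulPartitionedCall", ["input"], function=inner),
    "noop": Node("noop", "NoOp", ["call"]),
}, outputs=["noop"])

try:
    dependent_sources(outer, "noop")
    nested_error = None
except UnexpectedTableError as err:
    nested_error = str(err)
```

In this model `flat_sources` is `{"input"}`, while the walk starting at the `NoOp` stops with an error naming the table inside `loaded_fn`.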

Is there a way to work around this restriction?

Previously closed issues make it clear that this is intended to work:

singhniraj08 commented 1 year ago

@IzakMaraisTAL, I was able to replicate this issue. @zoyahav, can you please have a look? Thank you!

IzakMaraisTAL commented 1 year ago

Excellent, thanks for the feedback.