iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Preserve input/output signatures from the TensorFlow frontend #7194

Open phoenix-meadowlark opened 3 years ago

phoenix-meadowlark commented 3 years ago

Currently the TensorFlow frontend implicitly folds empty tuples, lists and dicts out of compiled functions' input and output signatures. When an argument or return value is empty, this changes the function's arity or the number of values it returns. This produces surprising errors and makes it harder to incorporate IREE into an existing pipeline.
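The folding can be illustrated without TensorFlow, assuming the frontend flattens each nested signature into a flat list of tensors. The sketch below uses two hypothetical helpers, `flatten` and `unflatten`, written to mimic `tf.nest.flatten` and `tf.nest.pack_sequence_as` for lists, tuples and dicts; they are not IREE or TensorFlow code. Flattening keeps only the leaves, so empty containers vanish and the arity changes. Keeping the original structure next to the flat values is enough to rebuild it, which is one possible way to preserve the signature.

```python
from typing import Any, Iterator, List

_MISSING = object()  # Sentinel marking an exhausted leaf iterator.


def flatten(structure: Any) -> List[Any]:
    """Return the leaves of a nested list/tuple/dict, left to right.

    Dict values are visited in sorted-key order (as tf.nest does). Empty
    containers contribute no leaves, so they disappear from the result.
    """
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure) for leaf in flatten(structure[key])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten(item)]
    return [structure]


def unflatten(structure: Any, leaves: List[Any]) -> Any:
    """Rebuild `structure`, replacing its leaves in order with `leaves`."""
    remaining: Iterator[Any] = iter(leaves)

    def build(node: Any) -> Any:
        if isinstance(node, dict):
            return {key: build(node[key]) for key in sorted(node)}
        if isinstance(node, (list, tuple)):
            return type(node)([build(item) for item in node])
        leaf = next(remaining, _MISSING)
        if leaf is _MISSING:
            raise ValueError("not enough leaves for the structure")
        return leaf

    result = build(structure)
    if next(remaining, _MISSING) is not _MISSING:
        raise ValueError("more leaves than the structure can hold")
    return result


# The `empty_list_multi` signature from the test below, with the string "b"
# standing in for tf.TensorSpec([4]).
signature = [[[]], "b", []]
print(flatten(signature))               # ['b']: three arguments became one leaf
print(unflatten(signature, ["x"]))      # [[[]], 'x', []]: original arity restored
```

The structure is the only extra information needed to undo the folding, which is why losing it at compile time is hard to recover from downstream.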

I created the test below, which should cover a small number of the relevant signatures.

import iree.compiler.tf
import iree.runtime
from absl import app
from iree.tf.support import tf_test_utils
import tensorflow as tf

class InputOutputModule(tf_test_utils.TestModule):

  # Empty lists: as the only argument, nested next to a tensor, and mixed
  # with other arguments.
  @tf_test_utils.tf_function_unit_test(input_signature=[[]])
  def empty_list_arg(self, a):
    return a

  @tf_test_utils.tf_function_unit_test(input_signature=[
      [tf.TensorSpec([4]), []],
  ])
  def empty_list_nested(self, a):
    return a

  @tf_test_utils.tf_function_unit_test(input_signature=[
      [[]],
      tf.TensorSpec([4]),
      [],
  ])
  def empty_list_multi(self, a, b, c):
    return a, b, c

  # Empty tuples, covering the same three cases.
  @tf_test_utils.tf_function_unit_test(input_signature=[tuple()])
  def empty_tuple_arg(self, a):
    return a

  @tf_test_utils.tf_function_unit_test(input_signature=[
      (tf.TensorSpec([4]), tuple()),
  ])
  def empty_tuple_nested(self, a):
    return a

  @tf_test_utils.tf_function_unit_test(input_signature=[
      # A tuple containing an empty tuple. tuple(tuple()) would evaluate to
      # a plain empty tuple, so spell the nesting out explicitly.
      (tuple(),),
      tf.TensorSpec([4]),
      tuple(),
  ])
  def empty_tuple_multi(self, a, b, c):
    return a, b, c

  # Empty dicts, covering the same three cases.
  @tf_test_utils.tf_function_unit_test(input_signature=[{}])
  def empty_dict_arg(self, a):
    return a

  @tf_test_utils.tf_function_unit_test(input_signature=[
      {
          'tensor': tf.TensorSpec([4]),
          'maybe_empty_state': {},
      },
  ])
  def empty_dict_nested(self, a):
    return a

  @tf_test_utils.tf_function_unit_test(input_signature=[
      {
          'empty': {}
      },
      tf.TensorSpec([4]),
      {},
  ])
  def empty_dict_multi(self, a, b, c):
    return a, b, c

class InputOutputTest(tf_test_utils.TracedModuleTestCase):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._modules = tf_test_utils.compile_tf_module(InputOutputModule)

def main(argv):
  del argv  # Unused
  if hasattr(tf, 'enable_v2_behavior'):
    tf.enable_v2_behavior()
  # Create one unit test per @tf_function_unit_test method, then run them.
  InputOutputTest.generate_unit_tests(InputOutputModule)
  tf.test.main()


if __name__ == '__main__':
  app.run(main)

allieculp commented 2 years ago

Routing to @jpienaar to assess and add priority if needed.

allieculp commented 2 years ago

@not-jenni Can you help review and prioritize this on frontend rotation this week?

rsuderman commented 2 years ago

Prioritized as P3, since TensorFlow is lower priority than the other frontends.