tensorflow / addons

Useful extra functionality for TensorFlow 2.x maintained by SIG-addons
Apache License 2.0

Raise exceptions based on tensor's properties #260

Closed by WindQAQ 5 years ago

WindQAQ commented 5 years ago

The build just failed yesterday due to some shape checking fragments in C++ code. https://source.cloud.google.com/results/invocations/f48f6f28-c9a4-4912-b8c5-336f17167183

Currently, addons has three approaches (migrated from core TF) for checking a tensor's properties and raising an exception that blocks the subsequent computation:

  1. Pure if-else and raise statement https://github.com/tensorflow/addons/blob/d46dba1ef691f607a69e379fb3c0c2c16daec2fb/tensorflow_addons/image/dense_image_warp.py#L51-L53

  2. tf.debugging.* https://github.com/tensorflow/addons/blob/d46dba1ef691f607a69e379fb3c0c2c16daec2fb/tensorflow_addons/image/dense_image_warp.py#L70-L71

  3. OP_REQUIRES in C++ https://github.com/tensorflow/addons/blob/d46dba1ef691f607a69e379fb3c0c2c16daec2fb/tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op.cc#L57-L59

However, none of them is robust enough inside tf.function or in test cases decorated with @run_in_graph_and_eager_modes. (https://github.com/tensorflow/addons/issues/138, https://github.com/tensorflow/addons/pull/257)


While looking through core TensorFlow, I found this commit from Apr 10: https://github.com/tensorflow/tensorflow/commit/4b4a39e7120b1c7744f9686bc6cce9363846d7e6#diff-68b5e47db1d9389c8d12852996845819

Op raising InvalidArgumentError unless x is all negative. This can be used with tf.control_dependencies inside of tf.functions to block followup computation until the check has executed.

Does the doc encourage us to use tf.control_dependencies inside tf.function for things like shape checking?
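As a minimal sketch of the pattern the quoted doc describes (assuming TensorFlow 2.x; the function name `third_element_guarded` and the `input_signature` are illustrative and not taken from addons), a run-time assert can gate later ops roughly like this:

```python
import tensorflow as tf

# The leading dimension is left unknown, so the check cannot be resolved while tracing.
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def third_element_guarded(x):
    # Builds an assert op on the symbolic shape; it is evaluated when the graph runs.
    check = tf.debugging.assert_greater_equal(
        tf.shape(x)[0], 3, message="wrong shape")
    # Ops created inside this block get a control dependency on the assert.
    with tf.control_dependencies([check]):
        y = x[2]
    return y

print(third_element_guarded(tf.constant([1.0, 2.0, 3.0])).numpy())

try:
    third_element_guarded(tf.constant([1.0, 2.0]))
except tf.errors.InvalidArgumentError:
    print("rejected at graph-run time")
```

This only covers the case where the shape is unknown while the graph is built; with a fully static input shape the out-of-range slice may already be rejected during tracing.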

WindQAQ commented 5 years ago

cc @seanpmorgan and @facaiy

facaiy commented 5 years ago

@alextp Alexandre, could you take a look?

facaiy commented 5 years ago

I think

Please correct me if I'm wrong :-)

alextp commented 5 years ago

@facaiy is correct. TF supports both graph-build-time shape checks and graph-run-time shape checks. In tf.function both types of checks might be useful, though tf.function graphs tend to have more static shapes than most manually built tf graphs, so you get more mileage out of the static checks.

Looking at the bugs I see a few false / confused statements being made, so I'd like to understand better what the actual issue here is before I can help.
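To make the build-time versus run-time distinction concrete, here is a small sketch (assuming TensorFlow 2.x; `length` and `traced_static_dims` are hypothetical names used only for illustration). It records the static dimension seen while tracing and returns the dynamic one:

```python
import tensorflow as tf

traced_static_dims = []  # filled in while the function is being traced

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def length(x):
    # Static shape: a plain Python value available at graph-build time.
    # It is None here because the signature leaves the dimension unknown.
    traced_static_dims.append(x.shape[0])
    # Dynamic shape: a symbolic tensor that only has a value at graph-run time.
    return tf.shape(x)[0]

print(length(tf.zeros([5])).numpy())
print(traced_static_dims)
```

The Python-side append runs once, during tracing, which is why a plain `if` or `assert` on `x.shape` acts as a graph-build-time check while ops on `tf.shape(x)` act at graph-run time.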

WindQAQ commented 5 years ago

@alextp, hi Alexandre, I have some explanations and questions here:

  1. I misunderstood why the build failed this time, so the C++ part is fine now. Sorry about that.
  2. For the tf.debugging.* part, I am wondering how to use it inside tf.function:
import tensorflow as tf

class TestDebugging(tf.test.TestCase):

    @tf.function
    def foo(self, x):
        # Run-time assert on the symbolic shape, without explicit control dependencies.
        tf.debugging.assert_greater_equal(tf.shape(x)[0], 3, message="wrong shape")
        y = x[2]

    @tf.function
    def bar(self, x):
        # Same assert, but the slice is created under a control dependency on it.
        with tf.control_dependencies([tf.debugging.assert_greater_equal(tf.shape(x)[0], 3, message="wrong shape")]):
            y = x[2]

    def test_assert(self):
        with self.assertRaisesRegex(tf.errors.InvalidArgumentError, "wrong shape"):
            self.foo(tf.random.uniform(shape=(2,)))

    def test_assert_v2(self):
        with self.assertRaisesRegex(tf.errors.InvalidArgumentError, "wrong shape"):
            self.bar(tf.random.uniform(shape=(2,)))

The test case above will fail because the statement y = x[2] raises ValueError, which suggests that tf.debugging.* cannot block the following computation even when control_dependencies is added. Am I misusing something?

  3. For the pure if-else part, should we always check the static shape with tensor.shape instead of tf.shape(tensor)? I mean, the following code snippet is an incorrect use (it will fail):

class TestPureIfElse(tf.test.TestCase):

    @tf.function
    def foo(self, x):
        if tf.shape(x)[0] < 3:
            raise ValueError("wrong shape")

    def test_assert(self):
        self.foo(tf.random.uniform(shape=(3,)))

Many thanks for the help!
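For the pure if-else question, one way to think about a static check is as a three-valued answer: passes, fails, or cannot be decided at graph-build time. The sketch below is plain Python with a hypothetical helper name (`static_dim_at_least` is not part of TensorFlow); it assumes the caller passes something like `x.shape.as_list()`, or `None` when even the rank is unknown (`as_list()` raises in that case).

```python
def static_dim_at_least(dims, axis, minimum):
    """Return True or False when the static shape decides the check, else None.

    `dims` is a list of ints/None (for example from `x.shape.as_list()`),
    or None when the rank itself is unknown.
    """
    if dims is None:
        return None   # unknown rank: nothing can be decided at build time
    if axis >= len(dims):
        return False  # the tensor has no such axis at all
    if dims[axis] is None:
        return None   # unknown dimension: needs a run-time check instead
    return dims[axis] >= minimum


print(static_dim_at_least([2], axis=0, minimum=3))
print(static_dim_at_least([None], axis=0, minimum=3))
```

A `False` result is the case where a plain `raise` at graph-build time is appropriate; a `None` result is the case a symbolic check on `tf.shape(tensor)` would have to cover.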

alextp commented 5 years ago

What's happening in your two cases is that the ValueError is being raised at graph building time not graph run time, because in tf.function x has a static shape so we know that slicing it is invalid, but tf.shape returns a symbolic tensor and the assert is only evaluated at graph run time.

So to catch the x[2] during graph building time you need a static shape check as well as a dynamic shape check:

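@tf.function
def bar(x):
    # Static check: runs while the graph is being built.
    if x.shape.rank is not None and x.shape.rank >= 1 and x.shape[0] is not None:
        assert x.shape[0] >= 3
    # Dynamic check: runs when the graph executes.
    with tf.control_dependencies([tf.debugging.assert_greater_equal(tf.shape(x)[0], 3, message="wrong shape")]):
        y = x[2]

alextp commented 5 years ago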

And you might wonder, "what is the assert_greater_equal buying me, then?" It covers the case where the shape is not known at graph build time. So for example this makes it fail:

@tf.function
def bar(x):
    if x.shape.rank is not None and x.shape.rank >= 1 and x.shape[0] is not None:
        assert x.shape[0] >= 3
    with tf.control_dependencies([tf.debugging.assert_greater_equal(tf.shape(x)[0], 3, message="wrong shape")]):
        y = x[2]

# The leading dimension is unknown at graph-build time, so only the run-time assert can catch it.
fn = bar.get_concrete_function(tf.TensorSpec(dtype=tf.float32, shape=[None]))
fn(tf.random.uniform(shape=(2,)))

WindQAQ commented 5 years ago

Thank you for the information! I initially supposed that tf.function would convert it, but it seems that shape checking inside tf.function takes quite a lot of work... So to conclude, we should always check both static and dynamic shape, right?

Edit: And because C++ offers static shape checking, we should also check dynamic shape in Python?

alextp commented 5 years ago

I think ideally both static and dynamic shape checks should be there, yes.

On Wed, May 22, 2019 at 10:21 AM Tzu-Wei Sung notifications@github.com wrote:

Thank you for the information! I initially supposed that tf.function would convert it, but it seems that shape checking inside tf.function takes quite a lot of work... So to conclude, we should always check both static and dynamic shape, right?

Edit: And because C++ offers static shape checking, we should also check dynamic shape in Python?


seanpmorgan commented 5 years ago

Thanks very much for the help, Alexandre, and Tzu-Wei, thanks for writing out the distinct examples. I think we should include this information in our subpackage contribution guides or somewhere else convenient.