WindQAQ closed this issue 5 years ago.
cc @seanpmorgan and @facaiy
@alextp Alexandre, could you take a look?
I think
Please correct me if I'm wrong :-)
@facaiy is correct. TF supports both graph-build-time shape checks and graph-run-time shape checks. In tf.function both types of checks might be useful, though tf.function graphs tend to have more static shapes than most manually built tf graphs, so you get more mileage out of the static checks.
Looking at the bugs I see a few false / confused statements being made, so I'd like to understand better what the actual issue here is before I can help.
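As a minimal illustration (not from the thread; the function name is made up) of the two kinds of checks mentioned above, assuming x is a tf.Tensor:

    import tensorflow as tf

    def two_kinds_of_checks(x):  # hypothetical example name
        # Graph-build-time (static) check: x.shape is inspected in Python while
        # the graph is traced, so a violation raises immediately.
        if x.shape.ndims is not None and x.shape.ndims != 1:
            raise ValueError("expected a rank-1 tensor")
        # Graph-run-time (dynamic) check: tf.shape(x) is a symbolic tensor, so
        # this assert only fires when the graph actually executes.
        return tf.debugging.assert_greater_equal(tf.shape(x)[0], 3)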
@alextp, hi Alexandre, I have some explanations and questions here:
For the tf.debugging.* part, I am wondering how to use it inside tf.function:
import tensorflow as tf

class TestDebugging(tf.test.TestCase):
    @tf.function
    def foo(self, x):
        tf.debugging.assert_greater_equal(tf.shape(x)[0], 3, message="wrong shape")
        y = x[2]

    @tf.function
    def bar(self, x):
        with tf.control_dependencies([tf.debugging.assert_greater_equal(tf.shape(x)[0], 3, message="wrong shape")]):
            y = x[2]

    def test_assert(self):
        with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, "wrong shape"):
            self.foo(tf.random.uniform(shape=(2,)))

    def test_assert_v2(self):
        with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, "wrong shape"):
            self.bar(tf.random.uniform(shape=(2,)))
The test cases above fail because the statement y = x[2] raises a ValueError, indicating that tf.debugging.* could not block the following computation even though control_dependencies is added. Am I misusing something?
For the pure if-else statement part, should we always check the static shape with tensor.shape instead of tf.shape(tensor)? I mean, the following code snippet is a wrong use case (it will fail):
class TestPureIfElse(tf.test.TestCase):
    @tf.function
    def foo(self, x):
        if tf.shape(x)[0] < 3:
            raise ValueError("wrong shape")

    def test_assert(self):
        self.foo(tf.random.uniform(shape=(3,)))
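For comparison, a minimal sketch (not from the thread; foo_static is a hypothetical name) of the same check written against the static shape, assuming the input has a known rank; here the plain Python if/raise runs at trace time:

    @tf.function
    def foo_static(x):
        # x.shape[0] is a Python int (or None) at trace time, so this branch is
        # evaluated while the graph is built, not converted by autograph.
        if x.shape[0] is not None and x.shape[0] < 3:
            raise ValueError("wrong shape")
        return x[2]

    foo_static(tf.random.uniform(shape=(3,)))  # passes: the static dim is 3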
Many thanks for the help!
What's happening in your two cases is that the ValueError is raised at graph building time, not graph run time: in tf.function, x has a static shape, so we already know while building the graph that slicing it is invalid, whereas tf.shape returns a symbolic tensor and the assert is only evaluated at graph run time.
So to catch the x[2] during graph building time you need a static shape check as well as a dynamic shape check:
@tf.function
def bar(x):
    if x.shape is not None and x.shape.ndims >= 1:
        assert x.shape[0] >= 3
    with tf.control_dependencies([tf.debugging.assert_greater_equal(tf.shape(x)[0], 3, message="wrong shape")]):
        y = x[2]
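A possible way to see the static check fire (not from the thread): call bar with a tensor whose first dimension is statically known to be too small, so the Python assert fails while the function is being traced.

    # Hypothetical usage of the snippet above: x has static shape (2,), so the
    # Python `assert x.shape[0] >= 3` fails at trace time, before any graph runs.
    bar(tf.random.uniform(shape=(2,)))  # raises AssertionError during tracing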
And you might wonder, "what is the assert_greater_equal buying me, then?" It buys you the case where the shape is not known at graph build time. So, for example, this makes it fail:
@tf.function
def bar(x):
    if x.shape is not None and x.shape.ndims >= 1 and x.shape[0] is not None:
        assert x.shape[0] >= 3
    with tf.control_dependencies([tf.debugging.assert_greater_equal(tf.shape(x)[0], 3, message="wrong shape")]):
        y = x[2]

fn = bar.get_concrete_function(tf.TensorSpec(dtype=tf.float32, shape=[None]))
fn(tf.random.uniform(shape=(2,)))
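(One reading of this example, not stated explicitly above: since the concrete function is traced with shape=[None], x.shape[0] is None, the Python assert is skipped, and it is the dynamic assert_greater_equal that fails when fn is called, surfacing as an InvalidArgumentError at graph run time.)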
Thank you for the information! I initially supposed tf.function would convert it, but it seems that doing shape checking inside tf.function is quite a lot of work... So to conclude, we should always check both the static and the dynamic shape, right?
Edit: and because C++ offers static shape checking, should we also check the dynamic shape in Python?
I think ideally both static and dynamic shape checks should be there, yes.
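To make that conclusion concrete, here is a minimal sketch (not from the thread; check_min_first_dim and baz are hypothetical names) of a helper that combines both kinds of checks, in the spirit of the examples above:

    import tensorflow as tf

    def check_min_first_dim(x, min_dim=3):
        """Hypothetical helper: static check at build time, dynamic check at run time."""
        # Static check: runs in Python while the graph is traced, when the
        # first dimension is statically known.
        if x.shape.ndims is not None and x.shape.ndims >= 1 and x.shape[0] is not None:
            if x.shape[0] < min_dim:
                raise ValueError("wrong shape")
        # Dynamic check: returns an assert op that fails at graph run time.
        return tf.debugging.assert_greater_equal(
            tf.shape(x)[0], min_dim, message="wrong shape")

    @tf.function
    def baz(x):
        with tf.control_dependencies([check_min_first_dim(x)]):
            return x[2]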
Thanks very much for the help, Alexandre, and thanks, Tzu-Wei, for writing out the distinct examples. I think we should include this information in our subpackage contribution guides or somewhere else convenient.
The build just failed yesterday due to some shape checking fragments in C++ code. https://source.cloud.google.com/results/invocations/f48f6f28-c9a4-4912-b8c5-336f17167183
Currently, there are three approaches that can check a tensor's properties and raise exceptions to block the following computation in addons (migrating from core TF):

1. Pure if-else and raise statements: https://github.com/tensorflow/addons/blob/d46dba1ef691f607a69e379fb3c0c2c16daec2fb/tensorflow_addons/image/dense_image_warp.py#L51-L53
2. tf.debugging.*: https://github.com/tensorflow/addons/blob/d46dba1ef691f607a69e379fb3c0c2c16daec2fb/tensorflow_addons/image/dense_image_warp.py#L70-L71
3. OP_REQUIRES in C++: https://github.com/tensorflow/addons/blob/d46dba1ef691f607a69e379fb3c0c2c16daec2fb/tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op.cc#L57-L59

However, none of them is robust enough in tf.function and in test cases with @run_in_graph_and_eager_modes (https://github.com/tensorflow/addons/issues/138, https://github.com/tensorflow/addons/pull/257).

When checking core TensorFlow, I found this commit from Apr 10: https://github.com/tensorflow/tensorflow/commit/4b4a39e7120b1c7744f9686bc6cce9363846d7e6#diff-68b5e47db1d9389c8d12852996845819

According to the doc, does it encourage us to use tf.control_dependencies inside tf.function to do something like shape checking?