Closed: ageron closed this issue 5 years ago
@alextp @alexbw @brilee @jmd101 @aaandrewww
This is a very good issue that you raised.
Before I elaborate, let me add one more case to your suite. It is less awkward than range(tf.constant(x)) and also gives you more certainty about its behavior:
@tf.function
def foo4(x):
    for i in tf.range(10):
        x = x + 1
    return x
At the core of this confusing behavior is the need to support the use case of building a multi-layer neural network:
for i in range(num_layers):
    do_something_that_cannot_be_done_in_a_tf_while()
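For concreteness, here is a minimal sketch of why such a loop must stay a Python loop (the layer sizes, names, and num_layers value are illustrative, not from the original report): each iteration creates a layer, and hence new tf.Variables, which a tf.while_loop body cannot do, so AutoGraph must unroll the loop at trace time.
import tensorflow as tf

num_layers = 3  # a plain Python int: it decides how many layers exist

# The layers (and hence their weight variables) are created in a Python loop.
# A tf.while_loop body cannot create variables, so this kind of construction
# is exactly what forces AutoGraph to unroll rather than stage the loop.
layers = [tf.keras.layers.Dense(4, activation="relu") for _ in range(num_layers)]

@tf.function
def multi_layer(x):
    # Unrolled at trace time: the graph contains num_layers Dense computations.
    for layer in layers:
        x = layer(x)
    return x

print(multi_layer(tf.ones([1, 4])))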
I agree that this automatic behavior can be surprising, as you pointed out. For example, writing a training loop as for i in range(num_steps) will most likely not have the desired result.
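By contrast, a training-style loop usually should be staged. A minimal sketch of the pitfall and the staged alternative (the variable and function names are illustrative):
import tensorflow as tf

v = tf.Variable(0.0)

@tf.function
def train_unrolled(num_steps):
    # With a Python int, this loop is unrolled: the graph grows linearly with
    # num_steps, and the function is retraced for each new value of num_steps.
    for _ in range(num_steps):
        v.assign_add(1.0)

@tf.function
def train_staged(num_steps):
    # With tf.range (or a tensor bound), AutoGraph stages a single
    # tf.while_loop, whatever num_steps turns out to be at run time.
    for _ in tf.range(num_steps):
        v.assign_add(1.0)

train_unrolled(5)                # graph contains 5 assign ops
train_staged(tf.constant(1000))  # graph contains one while loop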
A robust solution is not obvious to me, so I'm opening the thread for discussion on a few alternatives, off the top of my head:
1. Add a check for non-staged loops: if the number of iterations is too large, or the resulting graph grows too much, output a warning to the user along the lines of "The loop x is not running in TensorFlow, is this what you intend?".
2. Output a warning if the loop is unstaged but could be staged (e.g. it doesn't create any resources).
3. Run the loop in TF more aggressively, perhaps by always attempting to stage it and falling back to unrolling only if we detect that e.g. a resource was created.
Thoughts?
@jmd1011 (copying the correct ID)
I think that in general, unless we can infer the user's intent with high probability, trying to guess is often worse than just providing a way for the user to pick one or the other. The documentation does mention the range() vs. tf.range() thing, but we could maybe make that more prominent. The range/tf.range toggle seems to me like the easiest way for the user to indicate intent.
Hi, thanks for the interesting feedback. So if I understand correctly, for i in range(...) will not be converted into a tf.while() block, but for i in tf.range(...) will (and so will for i in range(tensor))?
Ahh, makes more sense now.
@brilee, could you please point me to the part of the documentation that mentions this? I looked at the 2.0-preview documentation for tf.autograph.to_code() and to_graph(), as well as tf.function(), but I see no mention of this.
I agree with your point about guessing vs giving the user the choice. I think that if the range/tf.range choice is presented more prominently in the documentation, it should be understandable. However, if tf.range is the most common option (is it?), or if range is the most dangerous one, then perhaps it should be the reverse: range would be converted to a tf.while block, while tf.python_range (or something) would be treated like a regular python loop. Just dropping the idea here, not sure it's any good.
I'm looking at https://github.com/tensorflow/docs/blob/master/site/en/r2/guide/autograph.ipynb and realizing that the latest version doesn't mention this range/tf.range thing. I'll add a note to that colab.
Indeed, the bug originated from the fact that range was not expected to represent "don't convert", and it is the dangerous one. One possible fix is to output a clear error when range is called with a tensor argument.
However, this still doesn't solve the problem of for i in range(1000): train(). Always converting range, and erroring out when that's not possible (e.g. because variables are created, a case in which we'd require something like python_range), would in turn be guaranteed to be safe.
I just discussed this with @alextp. The problem is that the default behavior (range vs tf.range, python args vs tensor args, etc.) is quite surprising to new users (as I experienced last week when giving a TF2 course). Indeed, pretty much everyone expected foo(5) to behave like foo(tf.constant(5)) when using autograph=True. Similarly, everyone expected range to be included in the graph when autograph=True. How can we make the default behavior less surprising?
Perhaps one solution could be to treat python values as if they were wrapped in tf.constant(...) by default, for example:
@tf.function
def foo(x):
    for i in tf.range(10):
        x = x + 1
    return x

foo(5)               # okay, equivalent to `foo(tf.constant(5))`
foo(tf.constant(5))  # okay, equivalent
If you set convert_args_to_tensors=False when calling tf.function (the option name is to be defined, of course), then python arguments get treated like today:
@tf.function(convert_args_to_tensors=False)
def foo(x):
    for i in tf.range(10):
        x = x + 1
    return x

foo(5)               # okay, but NOT equivalent to `foo(tf.constant(5))`
foo(tf.constant(5))  # okay
Alternatively, python values could simply be rejected by default, with a helpful error message:
@tf.function
def foo(x):
    for i in tf.range(10):
        x = x + 1
    return x

foo(5)               # ERROR: "All arguments must be tensors. If you need python args, please set `support_python_args=True`."
foo(tf.constant(5))  # okay
So if a user wants to support python arguments to their function, they must set support_python_args=True when calling tf.function (again, this option name is to be defined):
@tf.function(support_python_args=True)
def foo(x):
    for i in tf.range(10):
        x = x + 1
    return x

foo(tf.constant(5))  # okay
foo(5)               # okay, since `support_python_args=True`, but NOT equivalent to `foo(tf.constant(5))`
Yet another solution could be to reject python args by default, but let users pass a special PythonArgumentSpec to indicate that a specific argument is expected to be a python arg:
@tf.function(input_signature=[tf.PythonArgumentSpec(name="x")])
def foo(x):
    for i in tf.range(10):
        x = x + 1
    return x

foo(tf.constant(5))  # ERROR: the argument is not expected to be a tensor
foo(5)               # okay
Regarding range vs tf.range, perhaps range should default to the same behavior as tf.range:
@tf.function
def foo(x):
    for i in range(10):  # equivalent to tf.range(10), by default there's no difference
        x = x + 1
    return x
But if you want range to behave like a regular python range (i.e., not be included in the graph), then set static_range=True (again, arg name to be defined):
@tf.function(static_range=True)
def foo(x):
    for i in range(10):  # NOT equivalent to tf.range(10)
        x = x + 1
    return x
Alternatively, perhaps we should always make range equivalent to tf.range, but then add a tf.static_range for whenever you need to run a static loop (i.e., one that is not added to the graph)?
These are just ideas, and they can certainly be improved a lot. Again, the end goal is to make the default behavior unsurprising to new users (i.e., range is added to the graph when using autograph=True, and foo(5) behaves like foo(tf.constant(5))), while making it possible to easily opt in to python values as arguments and static range behavior.
Wdyt?
Thank you Aurélien for your input. The insights you gathered are most valuable. Please continue to send us any new data points you might gather.
It would seem that new users tend to have different expectations compared to "veteran" developers who wrote lots of TF graph code.
One thing that becomes obvious is that mechanisms which reliably avoid ambiguity would be useful regardless of what the defaults are.
An extreme version of such a mechanism would be an "all-graph" mode, along the lines of what you suggested: @tf.function(autograph=STRICT), where everything runs in graph, and doing anything outside the graph requires special overrides.
As a side note, we recently pushed a change where a construct like range(tf.constant(n)) is unsupported and raises an error. However, that still leaves us with unexpected behavior when the argument to the function was a Python value, or when users wrote range(n) out of sheer convenience, fully expecting the loop to be in-graph. The hope was that this would drive the habit of using either range to always statically unroll, or tf.range to always run in graph.
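A small sketch of that habit, with op counts on the traced graphs to make the difference visible (the bound of 10 is illustrative, and the exact counts vary by TF version):
import tensorflow as tf

@tf.function
def unrolled(x):
    for _ in range(10):      # Python range: unrolled, ~10 add ops in the graph
        x = x + 1
    return x

@tf.function
def staged(x):
    for _ in tf.range(10):   # tf.range: staged as a single tf.while_loop
        x = x + 1
    return x

spec = tf.TensorSpec([], tf.int32)
print(len(unrolled.get_concrete_function(spec).graph.get_operations()))  # grows with the bound
print(len(staged.get_concrete_function(spec).graph.get_operations()))    # roughly constant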
Hi @mdanatg,
It's great that range(tf.constant(n)) is now rejected! :)
Would autograph=STRICT be the default? If so, I guess it would mean that by default range() would act like tf.range() (great!). Would it also mean that foo(5) would be rejected, or better, behave like foo(tf.constant(5))? I would vote for that!
People who want the current behavior could specify something like autograph=MIXED, right?
Actually, thinking about this some more, I should probably mention that the most surprising behavior (to me and to the participants in my course) was, by far, the fact that foo(n) traces the function once for each distinct value of n (when n is a python value). It is not clear to me how often this behavior is needed; I haven't run into any actual use case yet. So I would argue that the default should be to just reject python arguments, or at least convert them to tensors automatically (using tf.constant()).
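A quick way to observe the retracing is a Python print in the function body, since it only executes while tracing (this snippet is just an illustration, not from the original report):
import tensorflow as tf

@tf.function
def foo(n):
    print("tracing foo with n =", n)  # Python side effect: runs only while tracing
    return tf.square(n)

foo(1)               # prints: traces a graph specialized for n == 1
foo(2)               # prints again: another trace for n == 2
foo(2)               # no print: the n == 2 trace is reused
foo(tf.constant(1))  # prints once for this tensor dtype/shape...
foo(tf.constant(2))  # ...no print: the same trace handles any int32 scalar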
On the other hand, it is true that the use case of building a model (e.g., using the subclassing API) composed of multiple similar layers or submodels is quite common, and it does require using a static loop (i.e., range()) rather than a dynamic loop (i.e., tf.range()).
We'd have to think very carefully about adding modes like STRICT and MIXED. We want to avoid adding complexity to the interface unless absolutely necessary, so it might take a bit to decide whether we move from idea to implementation. The biggest unanswered question, as you point out, is what to do with things like hyperparameters (e.g. train=True/False, num_layers, etc.). Perhaps those would have to be marked explicitly.
The idea of autoboxing literals is interesting. What I wonder is: what should be the behavior of a constant created inside the function (e.g. n = 3)? Should that be autoboxed, consistent with the function argument? That might not be feasible to do for @tf.function alone.
@alextp I think the possibility of autoboxing is a question for the interface of @tf.function, so it should be discussed in that context.
In terms of literals, there's a variety of ways that the graph can change with different python literals - num_layers, mode=TRAINING, unroll_steps, etc.
For the autoboxing question... probably the least surprising way to go about things is to copy NumPy's autoboxing semantics. But I don't know that NumPy is a shining star for unsurprising behavior - I discovered just now that [1, 2, 3] * np.array([1, 2, 3]) == array([1, 4, 9]).
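For reference, a tiny snippet contrasting plain Python list repetition with NumPy autoboxing the list and multiplying elementwise:
import numpy as np

print([1, 2, 3] * 3)                    # [1, 2, 3, 1, 2, 3, 1, 2, 3]  (list repetition)
print([1, 2, 3] * np.array([1, 2, 3]))  # [1 4 9]  (list autoboxed to an array, elementwise)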
@ageron The decision to retrace on python scalars was to allow people to write code like
@tf.function
def apply_resnet(n_layers, image): ...
We take this conservative retracing approach because (1) the set of things you're allowed to do with a python scalar is strictly bigger than the set of things you're allowed to do with a tf scalar (see here for example using it to set the number of layers of a resnet), and (2) for uniformity between argument passing and closure capturing, i.e. there is no way to special-case when a tf.function uses a python scalar variable defined outside the function and make that pretend to be a tf scalar.
I am curious, though, because I tried to make sure that retracing is mostly invisible, with only a negative effect on performance. So what kind of behavior did your students witness that was confusing to them?
(we have a plan to diminish the frequency of retraces wrt shapes which I think can be extended to scalars, but it does have a side effect of no longer letting us use a python dictionary to do trace lookup, so I want to defer adding it to see if someone else comes up with a better way)
Thanks @alextp , I understand your constraints a bit better, in particular the need for uniformity between argument passing and closure capturing.
Regarding the participants (and myself), the surprise came from examining the graphs generated by various functions, and the number of times the functions were traced. Indeed, the outputs were the ones we expected; it's just a matter of performance, and possibly of running out of RAM. But there can be a 100x performance difference or more, so it's not a small detail. It's like deque vs list: sometimes knowing how something is implemented is critical.
Moreover, it may not be trivial to debug. I fear that many users will write range() instead of tf.range(), or call foo(5) instead of foo(tf.constant(5)). I'm just trying to think of ways to reduce that risk (other than having super clear warnings in the documentation and good code examples). Perhaps there's no way around it, but I think it's worth exploring.
Thank you all for the useful discussion. This contributed to formulating a set of semantics that are consistent across autograph, which are summarized here:
A few concerns remain around expectations for first-time users, but I'm hopeful those can be addressed by documentation and tutorials.
System information
Describe the current behavior
tf.function converts almost identical functions into completely different graphs. It seems like a bug, but perhaps it's just a bit too complicated for me. If it's working as expected, then I think the documentation really needs to be expanded, with detailed examples and clear guidelines.
Describe the expected behavior
All the following functions should return almost identical graphs.
Code to reproduce the issue
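The original reproduction code is not included in this excerpt; below is only a plausible minimal sketch of the kind of near-identical functions being compared (the names, bounds, and bodies are illustrative):
import tensorflow as tf

@tf.function
def foo1(x):
    for i in range(10):               # Python range: unrolled into the graph
        x = x + 1
    return x

@tf.function
def foo2(x):
    for i in range(tf.constant(10)):  # range over a tensor (more recent versions reject this)
        x = x + 1
    return x

@tf.function
def foo3(x):
    for i in tf.range(10):            # tf.range: staged as a tf.while_loop
        x = x + 1
    return x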
Other info / logs
Below is the output of this program. Notice that some of the functions unroll the loop directly into the graph instead of using a tf.while_loop. Imagine a loop with 10000 iterations: the graph would just blow up.