Closed by khatchad 9 months ago
There is still a speed-up when using a tuple or list of tensors; therefore, it should be considered a "tensor-like" parameter.
The code is very similar to the example at https://www.tensorflow.org/tutorials/distribute/custom_training#iterating_inside_a_tffunction in the TF docs.
From running a Sequential model with a tuple of `tf.Tensor`s, the execution time with `tf.function` is 1.1418809699989652 s, and without `tf.function` it is 31.41282747400146 s.
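The timings above came from a Sequential model; as a minimal sketch, a comparison like that can be made with `timeit` on the small example below (the function and inputs here are illustrative, not the original benchmark):

```python
import timeit
import tensorflow as tf

def maxim(x):
    # Element-wise running maximum over a tuple of scalar tensors
    m = 0.
    for i in x:
        m = tf.maximum(m, i)
    return m

maxim_graph = tf.function(maxim)

args = (tf.constant(5.5), tf.constant(6.5))
maxim_graph(args)  # warm-up call so one-time tracing cost is excluded

t_graph = timeit.timeit(lambda: maxim_graph(args), number=1000)
t_eager = timeit.timeit(lambda: maxim(args), number=1000)
print(f"tf.function: {t_graph:.4f}s, eager: {t_eager:.4f}s")
```

For such a tiny function the gap is much smaller than for a full model, but the same pattern applies.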
For a simple example:

```python
@tf.function
def maxim(x):
    m = 0.
    for i in x:
        m = tf.maximum(m, i)
    return m

print(maxim((tf.constant(5.5), tf.constant(6.5))))
print(maxim((tf.constant(8.5), tf.constant(9.5))))
print(maxim.pretty_printed_concrete_signatures())
```
When calling `maxim(..)` two times as:

```python
print(maxim((tf.constant(5.5), tf.constant(6.5))))
print(maxim((tf.constant(8.5), tf.constant(9.5))))
```
We get the output of:

```
maxim(x)
  Args:
    x: (<1>, <2>)
      <1>: float32 Tensor, shape=()
      <2>: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()
```
`pretty_printed_concrete_signatures()` shows all of the traces. Since only one signature appears, the function was traced exactly once, and the second call does not create a new graph.
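The same conclusion can be checked more directly with the tracing counter on `tf.function` objects (assuming TF 2.4+, where `experimental_get_tracing_count()` is available):

```python
import tensorflow as tf

@tf.function
def maxim(x):
    m = 0.
    for i in x:
        m = tf.maximum(m, i)
    return m

maxim((tf.constant(5.5), tf.constant(6.5)))
maxim((tf.constant(8.5), tf.constant(9.5)))
# Same nested structure, dtypes, and shapes -> a single trace
print(maxim.experimental_get_tracing_count())  # 1
```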
If we do this with plain Python values:

```python
@tf.function
def maxim(x):
    m = 0.
    for i in x:
        m = tf.maximum(m, i)
    return m

print(maxim((3, 2)))
print(maxim((4, 5)))
print(maxim.pretty_printed_concrete_signatures())
```
When calling `maxim(..)` two times as:

```python
print(maxim((3, 2)))
print(maxim((4, 5)))
```
We get the output of:

```
maxim(x=(3, 2))
  Returns:
    float32 Tensor, shape=()

maxim(x=(4, 5))
  Returns:
    float32 Tensor, shape=()
```
As `pretty_printed_concrete_signatures()` shows, the function was traced twice, so the second call does create a new graph. Thus there is a real difference when the tuple/list contains `tf.Tensor`s: the function is not re-traced unnecessarily.
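The retracing in the Python-value case can also be confirmed with the tracing counter (again assuming TF 2.4+): each distinct tuple of Python literals is baked into the signature as constants, producing one trace per distinct input.

```python
import tensorflow as tf

@tf.function
def maxim(x):
    m = 0.
    for i in x:
        m = tf.maximum(m, i)
    return m

maxim((3, 2))
maxim((4, 5))
# Python literals are part of the input signature -> one trace per distinct tuple
print(maxim.experimental_get_tracing_count())  # 2
```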
From the TF docs:

> For Python ordered containers such as list and tuple, etc., the type is parameterized by the types of their elements...
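One consequence of that parameterization, sketched below (assuming `experimental_get_tracing_count()` is available): changing the tuple's length changes the container's type, forcing a retrace even when every element is a tensor of the same dtype and shape.

```python
import tensorflow as tf

@tf.function
def maxim(x):
    m = 0.
    for i in x:
        m = tf.maximum(m, i)
    return m

maxim((tf.constant(1.0), tf.constant(2.0)))                    # trace 1: 2-tuple
maxim((tf.constant(1.0), tf.constant(2.0), tf.constant(3.0)))  # trace 2: 3-tuple
print(maxim.experimental_get_tracing_count())  # 2
```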
Consider the following function, called here:
https://github.com/ponder-lab/mead-baseline/blob/4411029050d5549b34c75fe205451d6a1e80d335/mead/api_examples/pretrain_paired_tf.py#L329
Currently, we don't consider this function as having a tensor parameter. But it does have a tuple parameter that contains tensors (at least according to the type hint, and there is some evidence from the calling context). Should this be considered a "tensor-like" parameter?