ponder-lab / Hybridize-Functions-Refactoring

Refactorings for optimizing imperative TensorFlow clients for greater efficiency.
Eclipse Public License 2.0

Consider functions with containers of tensors as having "tensor" parameters #283

Closed: khatchad closed this issue 9 months ago

khatchad commented 10 months ago

Consider the following function:

@tf.function
def _distributed_train_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
    """Runs across multiple replicas and aggregates the results.

    :param inputs:
    :return:
    """
    per_replica_loss = strategy.run(_replicated_train_step, args=(inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

Called here:

https://github.com/ponder-lab/mead-baseline/blob/4411029050d5549b34c75fe205451d6a1e80d335/mead/api_examples/pretrain_paired_tf.py#L329

Currently, we don't consider this function to have a tensor parameter. However, it does have a tuple parameter that contains tensors (at least according to the type hint, and there is some evidence from the calling context). Should this be considered a "tensor-like" parameter?
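
For reference, tf.function can pin such a container parameter down explicitly with a nested input_signature, in which case the tuple elements are treated as tensor arguments for tracing. A minimal sketch; the scalar shapes/dtypes and the placeholder body are assumptions, not taken from the mead-baseline code:

import tensorflow as tf

# Hypothetical signature: a tuple of two scalar float32 tensors. tf.function
# expands the container into nested TensorSpecs, so each tuple element is
# treated as a tensor argument for tracing purposes.
@tf.function(input_signature=[(tf.TensorSpec(shape=(), dtype=tf.float32),
                               tf.TensorSpec(shape=(), dtype=tf.float32))])
def train_step(inputs):
    x, y = inputs
    return x + y  # placeholder body, not the real replicated step

print(train_step((tf.constant(1.0), tf.constant(2.0))))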

tatianacv commented 9 months ago

There is still a speed-up when using a tuple or list of tensors. Therefore, it should be considered a "tensor-like" parameter.

tatianacv commented 9 months ago

This code is very similar to the example at https://www.tensorflow.org/tutorials/distribute/custom_training#iterating_inside_a_tffunction in the TF docs.
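
Roughly, the pattern from that tutorial section looks like the sketch below. This is a self-contained toy version, not the mead-baseline code: the strategy, model, and synthetic data are placeholders, and the replica step only computes a loss rather than applying gradients.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])

def _replica_step(inputs):
    x, y = inputs  # the per-replica batch arrives as a tuple of tensors
    pred = model(x, training=True)
    per_example_loss = tf.keras.losses.mean_squared_error(y, pred)
    return tf.nn.compute_average_loss(per_example_loss)

@tf.function
def distributed_train_step(inputs):
    per_replica_loss = strategy.run(_replica_step, args=(inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

# Synthetic (features, labels) batches; each element yielded by the
# distributed dataset is a tuple of tensors, as in the issue above.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal((64, 4)), tf.random.normal((64, 1)))).batch(16)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

for batch in dist_dataset:
    print(distributed_train_step(batch).numpy())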

tatianacv commented 9 months ago

Running a Sequential model with a tuple of tf.Tensors, the execution time with tf.function is 1.1418809699989652, and without tf.function it is 31.41282747400146.
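
For context, here is a rough sketch of the kind of timing comparison behind those numbers; the model size, batch shape, and iteration count are assumptions, and the absolute times will vary by machine:

import timeit
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(64, activation="relu"),
                             tf.keras.layers.Dense(1)])

def step(inputs):
    x, y = inputs  # a tuple of tensors, as in the issue
    pred = model(x, training=True)
    return tf.reduce_mean(tf.square(y - pred))

compiled_step = tf.function(step)  # graph-compiled variant of the same step

batch = (tf.random.normal((256, 32)), tf.random.normal((256, 1)))
compiled_step(batch)  # warm-up call so tracing time is not measured

print("with tf.function   :", timeit.timeit(lambda: compiled_step(batch), number=1000))
print("without tf.function:", timeit.timeit(lambda: step(batch), number=1000))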

tatianacv commented 9 months ago

As a simple example:

import tensorflow as tf

@tf.function
def maxim(x):
  # Returns the running maximum over the elements of x.
  m = 0.
  for i in x:
    m = tf.maximum(m, i)
  return m

print(maxim((tf.constant(5.5), tf.constant(6.5))))
print(maxim((tf.constant(8.5), tf.constant(9.5))))

print(maxim.pretty_printed_concrete_signatures())

When calling maxim(..) twice as above, we get the output:

maxim(x)
  Args:
    x: (<1>, <2>)
      <1>: float32 Tensor, shape=()
      <2>: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

pretty_printed_concrete_signatures() can be used to see all the traces. It shows that the function was traced only once, so the second call does not create a new graph.
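
Another way to confirm this is the tracing counter on the traced function; a small sketch using the same maxim function (experimental_get_tracing_count() is an experimental API, so it may change between TF versions):

import tensorflow as tf

@tf.function
def maxim(x):
  m = 0.
  for i in x:
    m = tf.maximum(m, i)
  return m

maxim((tf.constant(5.5), tf.constant(6.5)))
maxim((tf.constant(8.5), tf.constant(9.5)))
# Prints 1: the second tuple of tensors reuses the existing trace.
print(maxim.experimental_get_tracing_count())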

If we do the same with plain Python values instead:

@tf.function
def maxim(x):
  m = 0.
  for i in x:
    m = tf.maximum(m, i)
  return m

print(maxim((3, 2)))
print(maxim((4, 5)))

print(maxim.pretty_printed_concrete_signatures())

When calling maxim(..) twice as above, we get the output:

maxim(x=(3, 2))
  Returns:
    float32 Tensor, shape=()

maxim(x=(4, 5))
  Returns:
    float32 Tensor, shape=()

As pretty_printed_concrete_signatures() shows, the function was traced twice, so the second call does create a new graph.

So there is a clear difference when tf.Tensors are passed inside a tuple/list: the function does not retrace unnecessarily, whereas plain Python values cause a retrace for every new value.
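
One related corner case for the refactoring: a container that mixes tensors with plain Python values still retraces whenever the Python element changes. A small illustration (not from the issue, just an assumed example):

import tensorflow as tf

@tf.function
def scale(x):
  value, factor = x   # value is a tf.Tensor, factor is a plain Python number
  return value * factor

scale((tf.constant(2.0), 3))
scale((tf.constant(4.0), 3))   # same Python factor: the trace is reused
scale((tf.constant(2.0), 5))   # new Python factor: triggers a retrace
# Prints 2: two traces in total.
print(scale.experimental_get_tracing_count())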

khatchad commented 9 months ago

https://www.tensorflow.org/guide/function#rules_of_tracing

khatchad commented 9 months ago

Quoting the rules of tracing: "For Python ordered containers such as list and tuple, etc., the type is parameterized by the types of their elements..."
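
In other words, the trace key for a tuple is its structure plus the types of its elements, so changing either forces a new trace. A quick illustration of the assumed behavior under the rules above:

import tensorflow as tf

@tf.function
def maxim(x):
  m = 0.
  for i in x:
    m = tf.maximum(m, i)
  return m

maxim((tf.constant(5.5), tf.constant(6.5)))                    # traced for a 2-tuple of float32 scalars
maxim((tf.constant(1.5), tf.constant(2.5)))                    # same element types: trace reused
maxim((tf.constant(1.5), tf.constant(2.5), tf.constant(3.5)))  # different tuple structure: new trace
# Prints 2: two traces in total.
print(maxim.experimental_get_tracing_count())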