keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

(torch) Use pure-python implementation of `tree` when in dynamo context #18614

Open kiukchung opened 1 year ago

kiukchung commented 1 year ago

Motivation

any_symbolic_tensors is called by pretty much every op, and internally it uses tree.flatten() to check whether any of the op's positional or keyword arguments is a KerasTensor (i.e., a symbolic tensor). torchdynamo skips tracing tree (presumably because it has C bindings), which causes a graph break at each op. This results in poor jitted performance for the PyTorch backend: because the graph breaks occur between every pair of ops, we lose opportunities for significant compiler optimizations (e.g., operator fusion).

See https://github.com/keras-team/keras/pull/18569 for more details.
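To make the check concrete, here is a minimal pure-python sketch of what any_symbolic_tensors has to answer: does any positional or keyword argument, at any nesting depth, contain a KerasTensor? KerasTensor below is a hypothetical stand-in class and _contains_symbolic is a hypothetical helper; this recursive walk only illustrates the semantics and is not Keras's actual implementation, which uses tree.flatten() as described above.

```python
from collections.abc import Mapping
from typing import Any, Optional, Sequence


class KerasTensor:
    """Hypothetical stand-in for Keras's symbolic tensor class."""

    def __init__(self, shape):
        self.shape = shape


def _contains_symbolic(x: Any) -> bool:
    # Recurse into mappings, lists and tuples; anything else is a leaf.
    if isinstance(x, KerasTensor):
        return True
    if isinstance(x, Mapping):
        return any(_contains_symbolic(v) for v in x.values())
    if isinstance(x, (list, tuple)):
        return any(_contains_symbolic(v) for v in x)
    return False


def any_symbolic_tensors(
    args: Optional[Sequence[Any]] = None,
    kwargs: Optional[Mapping[str, Any]] = None,
) -> bool:
    """Return True if any (possibly nested) argument is a KerasTensor."""
    return _contains_symbolic(tuple(args or ())) or _contains_symbolic(
        dict(kwargs or {})
    )
```

Because any() short-circuits, this walk stops at the first symbolic tensor it finds, whereas a flatten-then-scan approach always visits every leaf.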

Proposal

NOTE: making any_symbolic_tensors() dynamo-traceable does not guarantee that everything in keras will be dynamo compatible. Once we fix this, other issues may arise.

  1. Provide a dynamo-traceable, pure-python version of tree.flatten() and use it in place of tree.flatten() to prevent graph breaks at any_symbolic_tensors().
  2. If 1) is not enough (that is, we still observe graph breaks, albeit less frequently, due to other usages of tree.*), then, as suggested by @fchollet in https://github.com/keras-team/keras/pull/18569, we need to create a keras.utils.tree module that uses pure-python implementations when in a dynamo context, and replace usages of tree.* with keras.utils.tree.*.
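Step 1) could start from something like the following sketch: a recursive, pure-python flatten that treats lists, tuples (including namedtuples) and mappings as containers and everything else as a leaf. It assumes dm-tree's conventions of emitting mapping values in sorted-key order (so keys must be mutually comparable) and of treating None as a leaf; other container types that dm-tree supports are not covered.

```python
from collections.abc import Mapping
from typing import Any, List


def flatten(structure: Any) -> List[Any]:
    """Pure-python flatten; mapping values are emitted in sorted-key order."""
    leaves: List[Any] = []
    _flatten_into(structure, leaves)
    return leaves


def _flatten_into(x: Any, out: List[Any]) -> None:
    if isinstance(x, Mapping):
        # Sort keys for a deterministic order, independent of insertion order.
        for key in sorted(x):
            _flatten_into(x[key], out)
    elif isinstance(x, (list, tuple)):
        # namedtuples are tuples, so they are flattened field by field.
        for item in x:
            _flatten_into(item, out)
    else:
        # Anything else, including None, is a leaf.
        out.append(x)
```

Since it uses only isinstance checks and recursion, this is the kind of code torchdynamo can usually inline, but whether it actually removes the graph break at any_symbolic_tensors() would still need to be confirmed on the real op path.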

My suggestion is to do 1) first and then see whether 2) is needed, since 2) is a bigger change that we may not actually need.

fchollet commented 1 year ago

Do we have a way to tell when we're in a Dynamo context?
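One possible answer, sketched under the assumption of a PyTorch 2.x install: torch.compiler.is_compiling() (newer releases) and torch._dynamo.is_compiling() (older 2.x releases) return True while dynamo is tracing and False in eager mode. Exactly which release introduced each helper is not pinned down here, so the hypothetical in_dynamo_context() below probes for both and falls back to False when torch or the helper is missing.

```python
def in_dynamo_context() -> bool:
    """Best-effort check: True only while torch.compile/dynamo is tracing."""
    try:
        import torch
    except ImportError:
        return False  # no torch installed, so no dynamo either

    # Newer releases expose a public helper under torch.compiler.
    public_helper = getattr(getattr(torch, "compiler", None), "is_compiling", None)
    if public_helper is not None:
        return bool(public_helper())

    # Older 2.x releases only have the private torch._dynamo helper.
    try:
        from torch._dynamo import is_compiling as private_helper
    except ImportError:
        return False  # e.g. torch 1.x, which has no dynamo at all
    return bool(private_helper())
```

Called from ordinary eager code, this returns False; dynamo special-cases these helpers so that they report True inside a traced function.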

AakashKumarNain commented 1 year ago

we need to create a keras.utils.tree that uses pure-python implementations when in dynamo context and replace usages of tree.* with keras.utils.tree.*.

Should we start using Optree instead?

mattdangerw commented 1 year ago

If we add a keras.utils.tree, let's make sure to export it. We are using dm-tree in KerasNLP (downstream of Keras), and if we have a torch-friendly solution for nested structures, it would be great to be able to leverage it.

ASEM000 commented 11 months ago

+1 for optree.

haifeng-jin commented 6 months ago

Hi @james77777778 ,

It seems optree is not implemented in pure Python; it also relies on compiled C++ code. We found that it is still not compatible with torch dynamo.

Is my understanding correct? What was the reason for swapping dm-tree for optree?

We may need this info for future work on supporting torch dynamo. Thanks!

james77777778 commented 6 months ago

Is my understanding correct? What was the reason for swapping dm-tree for optree?

Actually, I just picked up the item from #18442 (a few months ago).

In the PR, I didn't achieve a significant speed-up for the torch backend by replacing dm-tree with optree: https://github.com/keras-team/keras/pull/19306#issuecomment-1996541519

It's a bit strange to me as well, considering that optree has been integrated into torch.


Perhaps we need a completely pure-python implementation for dynamo?
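A completely pure-python replacement would need more than flatten() if other tree.* calls also cause graph breaks. As one example, here is a hedged sketch of a single-structure map_structure that applies a function to every leaf and rebuilds the same nesting. It assumes mapping types can be rebuilt from a plain dict (true for dict and OrderedDict, not for defaultdict) and does not support dm-tree's multi-structure form.

```python
from collections.abc import Mapping
from typing import Any, Callable


def map_structure(fn: Callable[[Any], Any], structure: Any) -> Any:
    """Apply fn to every leaf of a single nested structure, keeping its shape."""
    if isinstance(structure, Mapping):
        # Assumes the mapping type's constructor accepts a plain dict.
        return type(structure)(
            {k: map_structure(fn, v) for k, v in structure.items()}
        )
    if isinstance(structure, tuple) and hasattr(structure, "_fields"):
        # namedtuple: its constructor takes the fields positionally.
        return type(structure)(*(map_structure(fn, v) for v in structure))
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure(fn, v) for v in structure)
    return fn(structure)  # leaf
```

Like the flatten sketch, this uses only isinstance checks and recursion, which is the property that should make it traceable by dynamo; that would still need to be verified in practice.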