kiukchung opened this issue 1 year ago
Do we have a way to tell when we're in a Dynamo context?
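One possible answer, assuming a recent PyTorch release: `torch.compiler.is_compiling()` returns `True` while Dynamo is tracing, and older releases expose a private `torch._dynamo.is_compiling()` instead. The helper name `in_dynamo_context` below is hypothetical, not an existing Keras API. The lookup is done once at import time so that the function body itself stays trivial for Dynamo to trace.

```python
try:
    import torch
except ImportError:  # PyTorch not installed
    torch = None


def _resolve_is_compiling():
    """Find a callable that reports whether Dynamo is currently tracing."""
    if torch is None:
        return None
    # Public helper in recent PyTorch releases.
    fn = getattr(getattr(torch, "compiler", None), "is_compiling", None)
    if fn is None:
        try:
            # Older releases only expose a private helper.
            from torch._dynamo import is_compiling as fn
        except ImportError:
            fn = None
    return fn


_IS_COMPILING = _resolve_is_compiling()


def in_dynamo_context() -> bool:
    """Hypothetical helper: True while TorchDynamo is tracing the caller."""
    return _IS_COMPILING is not None and bool(_IS_COMPILING())


print(in_dynamo_context())  # False when not running under torch.compile
```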
We need to create a `keras.utils.tree` that uses pure-Python implementations when in a Dynamo context and replace usages of `tree.*` with `keras.utils.tree.*`.
Should we start using Optree instead?
If we add a `keras.utils.tree`, let's make sure to export it. We are using `dm-tree` in KerasNLP, downstream of Keras, and if we have a torch-friendly solution for nested structures, it would be great to be able to leverage it.
+1 for optree.
Hi @james77777778,
It seems `optree` is not implemented in pure Python; it also uses some C code. We found it is still not compatible with torch dynamo.
Is my understanding correct? What was the reason we wanted to swap `dm-tree` for `optree`?
We may need this info for future work on supporting torch dynamo. Thanks!
Actually, I just picked up the item from #18442 (a few months ago).
In the PR, replacing `dm-tree` with `optree` did not give a significant speed-up for the torch backend:
https://github.com/keras-team/keras/pull/19306#issuecomment-1996541519
This is a bit strange to me as well, considering that `optree` has been integrated into torch.
Perhaps we need a completely pure-Python implementation for Dynamo?
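A minimal sketch of what a pure-Python `flatten` could look like, assuming it only needs to handle lists, tuples (including namedtuples) and mappings. It follows dm-tree's ordering for those types (sequences in order, mapping values by sorted key) and treats everything else, including `None` and strings, as a leaf; other container types that dm-tree supports are not covered. It uses an explicit stack instead of recursion; whether Dynamo traces it without a graph break is an assumption that would still need to be verified.

```python
from collections.abc import Mapping


def flatten(structure):
    """Pure-Python flatten over lists, tuples and mappings (dm-tree ordering)."""
    leaves = []
    stack = [structure]
    while stack:
        node = stack.pop()
        if isinstance(node, Mapping):
            # Push values in reverse sorted-key order so they pop in sorted order.
            stack.extend([node[k] for k in sorted(node, reverse=True)])
        elif isinstance(node, (list, tuple)):
            # Reverse so the first element is popped (and emitted) first.
            stack.extend(reversed(node))
        else:
            # Anything else, including None and str, is a leaf.
            leaves.append(node)
    return leaves


print(flatten([1, {"b": 2, "a": (3, 4)}, None]))  # [1, 3, 4, 2, None]
```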
Motivation
`any_symbolic_tensors` is called on pretty much every op, and it internally uses `tree.flatten()` to check whether any positional or keyword argument to the op is a `KerasTensor` (i.e. a symbolic tensor). TorchDynamo skips tracing `tree` (presumably because it has C bindings), which causes a graph break at each op. This results in poor jitted performance for the PyTorch backend: since the graph breaks occur between every op, we lose opportunities for any significant compiler optimizations (e.g. operator fusion). See https://github.com/keras-team/keras/pull/18569 for more details.
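Because `any_symbolic_tensors` only needs a yes/no answer, a pure-Python check does not have to build the full flattened list: it can walk the nested arguments and return at the first symbolic tensor. In this sketch, `KerasTensor` is a hypothetical stand-in class, and the `(args, kwargs)` signature is an assumption based on the description above.

```python
from collections.abc import Mapping


class KerasTensor:
    """Hypothetical stand-in for Keras' symbolic tensor class."""

    def __init__(self, shape):
        self.shape = shape


def any_symbolic_tensors(args=None, kwargs=None):
    """Return True if any nested positional/keyword argument is a KerasTensor."""
    stack = [args or (), kwargs or {}]
    while stack:
        node = stack.pop()
        if isinstance(node, KerasTensor):
            return True  # early exit: no need to visit the remaining leaves
        if isinstance(node, Mapping):
            stack.extend(node.values())
        elif isinstance(node, (list, tuple)):
            stack.extend(node)
        # Any other object is a non-symbolic leaf and is simply dropped.
    return False
```

Visiting order does not matter here, so unlike a `flatten` replacement it can skip sorting mapping keys.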
Proposal
1) Write a pure-Python implementation of `tree.flatten()` and use it instead of dm-tree's `tree.flatten()` to prevent graph breaks at `any_symbolic_tensors()`.
2) If graph breaks still occur due to other usages of `tree.*`, then (as suggested by @fchollet in https://github.com/keras-team/keras/pull/18569) we need to create a `keras.utils.tree` that uses pure-Python implementations when in a Dynamo context and replace usages of `tree.*` with `keras.utils.tree.*`.
My suggestion is to first do 1), then see if 2) is needed, as 2) is a bigger change that we may not actually need.