thegodone opened 4 months ago
When Tensorflow performance sucks, these are the usual culprits:
scaled_dot_product_attention
which may leverage highly optimized kernels (e.g. FlashAttention). Sadly, there is no such thing in Tensorflow, so Nobuco computes attention with a naive algorithm.

I use tf.keras as much as I can instead of importing keras directly.
I compared the speed of the two implementations:
import tensorflow as tf
from tensorflow.keras.layers import Dropout

def converter_scaled_dot_product_attention1(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    def func(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        D = tf.shape(query)[-1]
        if scale is None:
            scale = tf.cast(D, query.dtype) ** -0.5
        # Corby's numerically more stable attention
        # See: https://github.com/bigscience-workshop/Megatron-DeepSpeed/pull/118
        s_scale = tf.cast(tf.sqrt(scale), query.dtype)
        query = query * s_scale
        key = key * s_scale
        sim = tf.matmul(query, key, transpose_b=True)
        if attn_mask is not None:
            # Assumes a float 0/1 mask where 1 marks positions to mask out
            sim += attn_mask * -1e9
        elif is_causal:
            L = tf.shape(query)[-2]
            S = tf.shape(key)[-2]
            causal_mask = tf.linalg.band_part(tf.ones((L, S)), -1, 0)
            sim = sim * causal_mask + (1.0 - causal_mask) * -1e9
        attn = tf.nn.softmax(sim, axis=-1)
        if dropout_p > 0:
            attn = Dropout(dropout_p)(attn)
        return tf.matmul(attn, value)
    return func
and your code (slightly modified):
def tril(h, w):
    y = tf.range(0, h)[:, None]
    x = tf.range(0, w)[None, :]
    return y >= x

def converter_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    def func(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        D = tf.shape(query)[-1]
        if scale is None:
            scale = tf.cast(D, query.dtype) ** -0.5
        # Corby's numerically more stable attention
        # See: https://github.com/bigscience-workshop/Megatron-DeepSpeed/pull/118
        s_scale = tf.cast(tf.sqrt(scale), query.dtype)
        query = query * s_scale
        key = key * s_scale
        sim = query @ tf.experimental.numpy.swapaxes(key, -2, -1)
        if attn_mask is not None:
            # Boolean mask convention: True = attend, False = mask out
            sim = tf.where(attn_mask, sim, float("-inf"))
        elif is_causal:
            L = tf.shape(query)[-2]
            S = tf.shape(key)[-2]
            causal_mask = tril(L, S)
            sim = tf.where(causal_mask, sim, float("-inf"))
        attn = tf.nn.softmax(sim, axis=-1)
        # Note: this instantiates a new Dropout layer on every call
        attn = tf.keras.layers.Dropout(dropout_p)(attn)
        return attn @ value
    return func
I got almost the same speed, with a very slight improvement for version 1.
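For reference, here is roughly how I compared them (a minimal sketch; the tensor shapes and loop count are arbitrary, not the ones from my actual model):

import time
import tensorflow as tf

B, H, L, D = 1, 8, 128, 64  # arbitrary batch, heads, sequence length, head dim
q = tf.random.normal((B, H, L, D))
k = tf.random.normal((B, H, L, D))
v = tf.random.normal((B, H, L, D))

for name, converter in [("version 1", converter_scaled_dot_product_attention1),
                        ("nobuco", converter_scaled_dot_product_attention)]:
    func = converter(q, k, v, is_causal=True)
    func(q, k, v, is_causal=True)  # warm-up call
    start = time.perf_counter()
    for _ in range(100):
        func(q, k, v, is_causal=True)
    print(name, (time.perf_counter() - start) / 100, "s per call")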
@thegodone One important thing I almost forgot about: Tensorflow really hates dynamic tensor shapes. To infer language models with varying context length, you should do input padding (see this example and the accompanying issue).
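Roughly, padding means something like this (a minimal sketch; max_length and pad_id are placeholders, adapt them to your tokenizer):

import numpy as np

def pad_to_fixed_length(token_ids, max_length=64, pad_id=0):
    # Pad (or truncate) to a fixed length so the converted model always
    # sees the same input shape instead of a dynamic one.
    token_ids = token_ids[:max_length]
    padded = np.full((max_length,), pad_id, dtype=np.int32)
    padded[:len(token_ids)] = token_ids
    return padded

src = pad_to_fixed_length([9, 20, 28, 20])[None, :]  # shape (1, max_length)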
Turns out, the inference is much faster if the Keras model is exported as a SavedModel artifact:
keras_model.export(model_path)
saved_model = tf.saved_model.load(model_path)
saved_model.serve(inputs)
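When benchmarking this, keep in mind that the very first call to the restored endpoint can carry one-time setup cost, so it is worth timing a warmed-up call separately (a minimal sketch, reusing model_path and inputs from above):

import time
import tensorflow as tf

saved_model = tf.saved_model.load(model_path)
saved_model.serve(inputs)  # warm-up; the first call may include one-time setup cost

start = time.perf_counter()
for _ in range(100):
    saved_model.serve(inputs)
print("SavedModel inference:", (time.perf_counter() - start) / 100, "s per call")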
Nice catch, it looks like some optimizations happen during the SavedModel export. I will try that, thanks a lot.
Indeed, almost everything is faster now, except the first profiler line, the built-in method tensorflow.python._pywrap_tfe.TFE_Py_Execute. This is strange:
(mlxgraphenv-py311) tgg@macbook-pro atr_igor % python testkeras2_savedmodel.py
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.container.ModuleList' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'transformer.MultiHeadAttention' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.dropout.Dropout' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.sparse.Embedding' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
src transform: [[9, 20, 28, 20]]
PyTorch Inference Time: 0.6295859813690186
Generator: tensor([[-29.3162, -8.5767, -8.5767, -8.1641, -8.7406, -8.5813, -8.7283,
-8.8958, -7.8562, -8.5062, -8.2777, -8.4021, -8.7413, -8.6585,
-8.9101, -8.7277, -8.8235, -7.7042, -8.4626, -7.2766, -0.8226,
-6.3867, -8.3603, -7.5078, -8.5743, -8.7478, -4.1617, -1.3617,
-7.9011, -7.7502, -5.9109, -8.5652, -1.3151, -8.6553, -8.7373,
-8.6551, -8.6658, -8.5221, -4.9030, -7.4738, -8.4073, -8.0457]],
grad_fn=<SelectBackward0>)
****************************************************************************************************
Keras Inference Time: 2.041451930999756
Generator: tf.Tensor(
[[-29.316212 -8.576733 -8.576746 -8.164137 -8.740633 -8.581325
-8.728261 -8.89583 -7.85616 -8.506172 -8.277732 -8.402089
-8.741335 -8.658456 -8.910144 -8.727736 -8.823455 -7.704239
-8.462603 -7.2765865 -0.8225823 -6.386689 -8.360339 -7.5077925
-8.574331 -8.747827 -4.161747 -1.3616791 -7.901073 -7.7501645
-5.910877 -8.565204 -1.3150749 -8.655296 -8.737296 -8.655108
-8.665803 -8.522052 -4.9029703 -7.4737854 -8.407324 -8.04574 ]], shape=(1, 42), dtype=float32)
****************************************************************************************************
Exception ignored in: <function AtomicFunction.__del__ at 0x16ef13880>
Traceback (most recent call last):
File "/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/atomic_function.py", line 291, in __del__
TypeError: 'NoneType' object is not subscriptable
(mlxgraphenv-py311) tgg@macbook-pro atr_igor % python testkeras2_savedmodel.py
[... same SourceChangeWarning messages as in the first run ...]
src transform: [[9, 20, 28, 20]]
Traceback (most recent call last):
File "/Users/tgg/Github/atr_igor/testkeras2_savedmodel.py", line 113, in <module>
profiler = cProfile.Profile()
^^^^^^^^
NameError: name 'cProfile' is not defined
(mlxgraphenv-py311) tgg@macbook-pro atr_igor % python testkeras2_savedmodel.py
[... same SourceChangeWarning messages as in the first run ...]
src transform: [[9, 20, 28, 20]]
PyTorch Inference Time: 0.7369470596313477
558604 function calls (471804 primitive calls) in 0.737 seconds
Ordered by: internal time
List reduced from 76 to 25 due to restriction <25>
ncalls tottime percall cumtime percall filename:lineno(function)
9700 0.276 0.000 0.276 0.000 {built-in method torch._C._nn.linear}
124600/111900 0.053 0.000 0.560 0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:285(decorator)
1800 0.046 0.000 0.046 0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f40f40}
3200 0.039 0.000 0.039 0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f41940}
1 0.030 0.030 0.737 0.737 /Users/tgg/Github/atr_igor/testkeras2_savedmodel.py:80(profile_pytorch_inference)
3600 0.029 0.000 0.029 0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4edf740}
25200/500 0.022 0.000 0.699 0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1534(_call_impl)
50400 0.020 0.000 0.020 0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1696(__getattr__)
9600 0.020 0.000 0.020 0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4edda80}
1800 0.017 0.000 0.411 0.000 /Users/tgg/Github/atr_igor/transformer.py:66(forward)
25200/500 0.014 0.000 0.698 0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:241(forward)
3200 0.013 0.000 0.013 0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f23600}
3200 0.012 0.000 0.109 0.000 /Users/tgg/Github/atr_igor/transformer.py:47(forward)
9700 0.011 0.000 0.297 0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/linear.py:115(forward)
25200/500 0.011 0.000 0.699 0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1528(_wrapped_call_impl)
149800 0.010 0.000 0.010 0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:212(is_tracing_enabled)
1800 0.009 0.000 0.121 0.000 /Users/tgg/Github/atr_igor/transformer.py:77(attention)
9000 0.009 0.000 0.009 0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f432e0}
4200 0.008 0.000 0.014 0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/functional.py:1279(dropout)
7200 0.008 0.000 0.008 0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f584a0}
5000 0.008 0.000 0.008 0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4ef0e00}
3000 0.007 0.000 0.664 0.000 /Users/tgg/Github/atr_igor/transformer.py:26(forward)
3400 0.006 0.000 0.006 0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4edf880}
1800 0.005 0.000 0.005 0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f22fc0}
1800 0.005 0.000 0.005 0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4ede340}
****************************************************************************************************
Keras Inference Time: 2.2881689071655273
1447802 function calls (1311645 primitive calls) in 2.286 seconds
Ordered by: internal time
List reduced from 1591 to 25 due to restriction <25>
ncalls tottime percall cumtime percall filename:lineno(function)
100 1.671 0.017 1.671 0.017 {built-in method tensorflow.python._pywrap_tfe.TFE_Py_Execute}
400 0.128 0.000 0.129 0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/constant_op.py:70(convert_to_eager_tensor)
667 0.030 0.000 0.030 0.000 {built-in method tensorflow.python.client._pywrap_tf_session.TF_FinishOperation}
168 0.029 0.000 0.029 0.000 {method '_numpy_internal' of 'tensorflow.python.framework.ops.EagerTensor' objects}
10721/155 0.016 0.000 0.096 0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/ast.py:488(generic_visit)
103263/30224 0.015 0.000 0.033 0.000 {built-in method builtins.hash}
178603/178541 0.012 0.000 0.032 0.000 {built-in method builtins.isinstance}
113900 0.010 0.000 0.011 0.000 {built-in method builtins.getattr}
5098/199 0.009 0.000 0.021 0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/autograph/pyct/ast_util.py:33(copy)
34870 0.008 0.000 0.024 0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/tensor.py:894(__hash__)
6316/44 0.008 0.000 0.090 0.002 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/autograph/pyct/transformer.py:417(visit)
16072/82 0.008 0.000 0.123 0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/ast.py:414(visit)
59573 0.007 0.000 0.007 0.000 {built-in method builtins.hasattr}
26280 0.007 0.000 0.012 0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/autograph/pyct/anno.py:130(hasanno)
46432 0.007 0.000 0.009 0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/ast.py:255(iter_fields)
1098 0.004 0.000 0.006 0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/typing.py:1911(_get_protocol_attrs)
24987 0.004 0.000 0.005 0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/enum.py:1230(__hash__)
34870 0.004 0.000 0.005 0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/tensor_shape.py:1508(__hash__)
667 0.004 0.000 0.043 0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/ops.py:959(_create_c_op)
231 0.004 0.000 0.004 0.000 {built-in method tensorflow.python.client._pywrap_tf_session.TF_GraphCopyFunction}
3502 0.004 0.000 0.009 0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:285(decorator)
797 0.003 0.000 0.004 0.000 {built-in method tensorflow.python.client._pywrap_tf_session.TF_OperationGetAttrValueProto}
2030/98 0.003 0.000 0.013 0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/gast/astn.py:17(generic_visit)
200 0.003 0.000 0.006 0.000 /Users/tgg/Github/atr_igor/data.py:237(pad_pack)
7339 0.003 0.000 0.003 0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/dtypes.py:264(__eq__)
****************************************************************************************************
Exception ignored in: <function AtomicFunction.__del__ at 0x16c817880>
Traceback (most recent call last):
File "/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/atomic_function.py", line 291, in __del__
TypeError: 'NoneType' object is not subscriptable
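(The failed run in the middle was just a missing import cProfile on my side.) For reference, the profiling wrapper looks roughly like this (a minimal sketch; keras_model and inputs stand in for my actual model and batch):

import cProfile
import pstats

profiler = cProfile.Profile()
profiler.enable()
output = keras_model(inputs)   # or saved_model.serve(inputs)
profiler.disable()

stats = pstats.Stats(profiler).sort_stats("tottime")  # prints "Ordered by: internal time"
stats.print_stats(25)                                 # the "restriction <25>" in the output above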
It probably deserves a mention in the README that performance suffers until the model is exported as a SavedModel.
After conversion, I still have an issue with speed. What is also strange is the difference in the number of function calls between torch and keras (about 560k vs. 1.45M in the profiles above).