AlexanderLutsenko / nobuco

Pytorch to Keras/Tensorflow/TFLite conversion made intuitive
MIT License

speed issue after conversion #54

Open thegodone opened 4 months ago

thegodone commented 4 months ago

After conversion, I have a speed issue: the converted Keras model is several times slower than the PyTorch one.

PyTorch Inference Time: 0.7236480712890625
         558606 function calls (471806 primitive calls) in 0.723 seconds

   Ordered by: internal time
   List reduced from 77 to 25 due to restriction <25>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     9700    0.268    0.000    0.268    0.000 {built-in method torch._C._nn.linear}
124600/111900    0.056    0.000    0.552    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:285(decorator)
     3200    0.040    0.000    0.040    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3aca41da0}
     3600    0.030    0.000    0.030    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3ac9dfba0}
        1    0.028    0.028    0.723    0.723 /Users/tgg/Github/atr_igor/testkeras5.py:73(profile_pytorch_inference)
     1800    0.027    0.000    0.027    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3aca413a0}
     9600    0.020    0.000    0.020    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3ac9ddee0}
25200/500    0.020    0.000    0.680    0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1534(_call_impl)
    50400    0.019    0.000    0.019    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1696(__getattr__)
     1800    0.016    0.000    0.387    0.000 /Users/tgg/Github/atr_igor/transformer.py:66(forward)
     3200    0.015    0.000    0.015    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3aca1fa60}
25200/500    0.014    0.000    0.679    0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:241(forward)
     3200    0.012    0.000    0.112    0.000 /Users/tgg/Github/atr_igor/transformer.py:47(forward)
     9700    0.012    0.000    0.289    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/linear.py:115(forward)
25200/500    0.011    0.000    0.680    0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1528(_wrapped_call_impl)
   149800    0.010    0.000    0.010    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:212(is_tracing_enabled)
     9000    0.009    0.000    0.009    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3aca43740}
     1800    0.009    0.000    0.104    0.000 /Users/tgg/Github/atr_igor/transformer.py:77(attention)
     4200    0.008    0.000    0.014    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/functional.py:1279(dropout)
     7200    0.008    0.000    0.008    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3aca58900}
     5000    0.008    0.000    0.008    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3ac9f1260}
     3400    0.008    0.000    0.008    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3ac9dfce0}
     3000    0.007    0.000    0.641    0.000 /Users/tgg/Github/atr_igor/transformer.py:26(forward)
     1800    0.006    0.000    0.006    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3aca1f420}
     1800    0.005    0.000    0.005    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3ac9de7a0}

****************************************************************************************************
Keras Inference Time: 2.867401123046875
         3564959 function calls (3444008 primitive calls) in 2.861 seconds

   Ordered by: internal time
   List reduced from 1691 to 25 due to restriction <25>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      100    1.592    0.016    1.592    0.016 {built-in method tensorflow.python._pywrap_tfe.TFE_Py_Execute}
      400    0.144    0.000    0.145    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/constant_op.py:70(convert_to_eager_tensor)
     3083    0.040    0.000    0.092    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/util/tf_decorator.py:272(unwrap)
460826/460807    0.031    0.000    0.096    0.000 {built-in method builtins.isinstance}
   223876    0.024    0.000    0.029    0.000 {built-in method builtins.hasattr}
     5218    0.020    0.000    0.028    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/typing.py:1911(_get_protocol_attrs)
   236394    0.018    0.000    0.019    0.000 {built-in method builtins.getattr}
    63091    0.017    0.000    0.029    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/util/tf_decorator.py:187(_has_tf_decorator_attr)
     2450    0.016    0.000    0.063    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/ops.py:959(_create_c_op)
77423/13928    0.014    0.000    0.031    0.000 {built-in method builtins.hash}
     2705    0.013    0.000    0.118    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/util/tf_decorator.py:179(_get_bound_instance)
   1174/5    0.013    0.000    0.858    0.172 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/keras/src/engine/base_layer.py:1005(__call__)
    36466    0.013    0.000    0.029    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/dtypes.py:793(as_dtype)
     2450    0.012    0.000    0.012    0.000 {built-in method tensorflow.python.client._pywrap_tf_session.TF_FinishOperation}
     1374    0.012    0.000    0.380    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/op_def_library.py:752(_apply_op_helper)
     2439    0.012    0.000    0.013    0.000 {built-in method tensorflow.python.client._pywrap_tf_session.TF_OperationGetAttrValueProto}
   180059    0.011    0.000    0.011    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/util/tf_decorator.py:343(decorated_target)
     1374    0.010    0.000    0.120    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/op_def_library.py:411(_ExtractInputsAndAttrs)
5822/5520    0.010    0.000    0.040    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/inspect.py:2428(_signature_from_callable)
    23844    0.009    0.000    0.025    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/op_def_library.py:55(<genexpr>)
     2705    0.008    0.000    0.173    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/util/tf_decorator.py:115(make_decorator)
     2450    0.008    0.000    0.178    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/ops.py:2593(_create_op_internal)
     3172    0.008    0.000    0.020    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/inspect.py:2333(_signature_from_function)
    60699    0.008    0.000    0.011    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/inspect.py:300(ismethod)
     2450    0.008    0.000    0.108    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/ops.py:1051(from_node_def)

What is also strange is the difference in the number of function calls between torch and Keras.
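A listing like the ones above can be produced with the standard-library cProfile and pstats modules. The sketch below is only an illustration: `profile_call` and `toy_model` are hypothetical names, not part of Nobuco or of the script used in this issue. It runs a few untimed warm-up calls first, on the assumption that one-off setup work (such as graph construction on the first call) should be kept out of the profile.

```python
import cProfile
import io
import pstats


def profile_call(fn, *args, warmup=1, runs=100, top=25):
    """Profile `runs` calls of fn(*args) after `warmup` unprofiled calls."""
    # Warm-up calls are executed outside the profiler on purpose.
    for _ in range(warmup):
        fn(*args)

    profiler = cProfile.Profile()
    profiler.enable()
    for _ in range(runs):
        fn(*args)
    profiler.disable()

    # Sort by internal time (tottime) and keep only the `top` entries.
    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    stats.sort_stats("tottime").print_stats(top)
    return buffer.getvalue()


def toy_model(x):
    # Hypothetical stand-in for a model's forward pass.
    return sum(v * v for v in x)


report = profile_call(toy_model, [1.0, 2.0, 3.0], runs=10)
print(report)
```

In a real comparison, `fn` would be the PyTorch module or the Keras model, called with the same inputs for both.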

AlexanderLutsenko commented 4 months ago

When Tensorflow performance sucks, these are the usual culprits:

  1. In PyTorch, transformers typically call scaled_dot_product_attention, which may leverage highly optimized kernels (e.g. FlashAttention). Sadly, there is no such thing in TensorFlow, so Nobuco computes attention with a naive algorithm.
  2. Advanced tensor slicing: TensorFlow lacks a good implementation of it. You can run this example and see how bulky the output graph is.

thegodone commented 4 months ago

I use tf.keras rather than the standalone keras import wherever I can.

I compared the speed of two implementations:

import tensorflow as tf

def converter_scaled_dot_product_attention1(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    def func(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        D = tf.shape(query)[-1]

        if scale is None:
            scale = tf.cast(D, query.dtype) ** -0.5

        # Corby's numerically more stable attention
        # See: https://github.com/bigscience-workshop/Megatron-DeepSpeed/pull/118
        s_scale = tf.cast(tf.sqrt(scale), query.dtype)
        query = query * s_scale
        key = key * s_scale

        sim = tf.matmul(query, key, transpose_b=True)

        if attn_mask is not None:
            # Assumes a boolean mask where True means "attend" (the PyTorch convention),
            # so masked-out positions receive a large negative bias.
            sim += (1.0 - tf.cast(attn_mask, sim.dtype)) * -1e9
        elif is_causal:
            L = tf.shape(query)[-2]
            S = tf.shape(key)[-2]
            # Lower-triangular matrix of ones, in the same dtype as the scores
            causal_mask = tf.linalg.band_part(tf.ones((L, S), dtype=sim.dtype), -1, 0)
            sim = sim * causal_mask + (1.0 - causal_mask) * -1e9

        attn = tf.nn.softmax(sim, axis=-1)
        if dropout_p > 0:
            # Keras Dropout is a no-op unless called with training=True
            attn = tf.keras.layers.Dropout(dropout_p)(attn)

        return tf.matmul(attn, value)

    return func

and your code (slightly modified):


def tril(h, w):
    y = tf.range(0, h)[:, None]
    x = tf.range(0, w)[None, :]
    return y >= x

def converter_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    def func(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        D = tf.shape(query)[-1]

        if scale is None:
            scale = tf.cast(D, query.dtype) ** -0.5

        # Corby's numerically more stable attention
        # See: https://github.com/bigscience-workshop/Megatron-DeepSpeed/pull/118
        s_scale = tf.cast(tf.sqrt(scale), query.dtype)
        query = query * s_scale
        key = key * s_scale

        sim = query @ tf.experimental.numpy.swapaxes(key, -2, -1)

        if attn_mask is not None:
            sim = tf.where(attn_mask, sim, float("-inf"))
        elif is_causal:
            L = tf.shape(query)[-2]
            S = tf.shape(key)[-2]
            causal_mask = tril(L, S)
            sim = tf.where(causal_mask, sim, float("-inf"))

        attn = tf.nn.softmax(sim, axis=-1)
        attn = tf.keras.layers.Dropout(dropout_p)(attn)
        return attn @ value
    return func

I got almost the same speed, with a very slight improvement for version "1".
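The two versions differ mainly in how masked positions are suppressed: version "1" adds a large negative bias (-1e9), while the other writes -inf with tf.where. As a minimal sketch in plain Python (not TensorFlow, and with made-up scores), the two give the same softmax output whenever at least one position in the row is kept:

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [1.0, 2.0, 3.0]    # illustrative attention scores for one query
keep = [True, True, False]  # True means the position may be attended to

# Version "1" style: add a large negative bias to masked positions.
additive = [s if k else s + (-1e9) for s, k in zip(scores, keep)]
# tf.where style: replace masked positions with -inf.
hard = [s if k else float("-inf") for s, k in zip(scores, keep)]

p_additive = softmax(additive)
p_hard = softmax(hard)
print(p_additive)
print(p_hard)
```

The two approaches only diverge for a fully masked row: with -inf every score is -inf and the softmax is undefined (NaN), whereas the -1e9 bias yields a uniform distribution.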

AlexanderLutsenko commented 4 months ago

@thegodone One important thing I almost forgot about: TensorFlow really hates dynamic tensor shapes. To run inference on language models with varying context length, you should pad the inputs to a fixed length (see this example and the accompanying issue).
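As a rough illustration of input padding, the sketch below pads a token sequence to a fixed length and builds a matching mask, so the model always sees the same input shape. The helper name `pad_to_length`, the padding id 0 and the length 8 are assumptions made for this example; the right values depend on the tokenizer and the model.

```python
def pad_to_length(token_ids, max_len, pad_id=0):
    """Right-pad token_ids to max_len and return (padded_ids, mask).

    mask holds 1 for real tokens and 0 for padding.
    pad_id=0 is an assumption; use the tokenizer's actual padding id.
    """
    if len(token_ids) > max_len:
        raise ValueError("sequence is longer than max_len")
    n_pad = max_len - len(token_ids)
    padded = list(token_ids) + [pad_id] * n_pad
    mask = [1] * len(token_ids) + [0] * n_pad
    return padded, mask


padded, mask = pad_to_length([9, 20, 28, 20], max_len=8)
print(padded)  # [9, 20, 28, 20, 0, 0, 0, 0]
print(mask)    # [1, 1, 1, 1, 0, 0, 0, 0]
```

The mask would then be passed to the attention layers so that padded positions are ignored.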

AlexanderLutsenko commented 4 months ago

Turns out, inference is much faster if the Keras model is exported as a SavedModel artifact:

keras_model.export(model_path)

saved_model = tf.saved_model.load(model_path)
saved_model.serve(inputs)
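
To compare the two call paths fairly, a small wall-clock harness with warm-up calls can help. This is only a sketch: `benchmark` and `stand_in_model` are hypothetical names, and in practice `fn` would be `keras_model` or `saved_model.serve`, called with the same inputs.

```python
import statistics
import time


def benchmark(fn, *args, warmup=3, runs=100):
    """Return the median wall-clock time in seconds of fn(*args)."""
    # Warm-up calls are not timed, so one-off setup cost is excluded.
    for _ in range(warmup):
        fn(*args)

    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


def stand_in_model(x):
    # Hypothetical placeholder for keras_model or saved_model.serve.
    return [v * 2 for v in x]


median_s = benchmark(stand_in_model, [1, 2, 3], runs=20)
print(f"median latency: {median_s * 1e6:.1f} us")
```

The median is used instead of the mean so that a few slow outlier calls do not dominate the result.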

thegodone commented 4 months ago

Nice catch, it looks like some optimizations are applied during the SavedModel export. I will try that, thanks a lot.

thegodone commented 4 months ago

Indeed, almost everything is faster now, except for the first line of the profile, the built-in method tensorflow.python._pywrap_tfe.TFE_Py_Execute, which is strange:

(mlxgraphenv-py311) tgg@macbook-pro atr_igor % python testkeras2_savedmodel.py
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.container.ModuleList' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'transformer.MultiHeadAttention' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.dropout.Dropout' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.sparse.Embedding' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
src transform: [[9, 20, 28, 20]]
PyTorch Inference Time: 0.6295859813690186
Generator: tensor([[-29.3162,  -8.5767,  -8.5767,  -8.1641,  -8.7406,  -8.5813,  -8.7283,
          -8.8958,  -7.8562,  -8.5062,  -8.2777,  -8.4021,  -8.7413,  -8.6585,
          -8.9101,  -8.7277,  -8.8235,  -7.7042,  -8.4626,  -7.2766,  -0.8226,
          -6.3867,  -8.3603,  -7.5078,  -8.5743,  -8.7478,  -4.1617,  -1.3617,
          -7.9011,  -7.7502,  -5.9109,  -8.5652,  -1.3151,  -8.6553,  -8.7373,
          -8.6551,  -8.6658,  -8.5221,  -4.9030,  -7.4738,  -8.4073,  -8.0457]],
       grad_fn=<SelectBackward0>)
****************************************************************************************************
Keras Inference Time: 2.041451930999756
Generator: tf.Tensor(
[[-29.316212   -8.576733   -8.576746   -8.164137   -8.740633   -8.581325
   -8.728261   -8.89583    -7.85616    -8.506172   -8.277732   -8.402089
   -8.741335   -8.658456   -8.910144   -8.727736   -8.823455   -7.704239
   -8.462603   -7.2765865  -0.8225823  -6.386689   -8.360339   -7.5077925
   -8.574331   -8.747827   -4.161747   -1.3616791  -7.901073   -7.7501645
   -5.910877   -8.565204   -1.3150749  -8.655296   -8.737296   -8.655108
   -8.665803   -8.522052   -4.9029703  -7.4737854  -8.407324   -8.04574  ]], shape=(1, 42), dtype=float32)
****************************************************************************************************
Exception ignored in: <function AtomicFunction.__del__ at 0x16ef13880>
Traceback (most recent call last):
  File "/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/atomic_function.py", line 291, in __del__
TypeError: 'NoneType' object is not subscriptable
(mlxgraphenv-py311) tgg@macbook-pro atr_igor % python testkeras2_savedmodel.py
src transform: [[9, 20, 28, 20]]
Traceback (most recent call last):
  File "/Users/tgg/Github/atr_igor/testkeras2_savedmodel.py", line 113, in <module>
    profiler = cProfile.Profile()
               ^^^^^^^^
NameError: name 'cProfile' is not defined
(mlxgraphenv-py311) tgg@macbook-pro atr_igor % python testkeras2_savedmodel.py
src transform: [[9, 20, 28, 20]]
PyTorch Inference Time: 0.7369470596313477
         558604 function calls (471804 primitive calls) in 0.737 seconds

   Ordered by: internal time
   List reduced from 76 to 25 due to restriction <25>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     9700    0.276    0.000    0.276    0.000 {built-in method torch._C._nn.linear}
124600/111900    0.053    0.000    0.560    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:285(decorator)
     1800    0.046    0.000    0.046    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f40f40}
     3200    0.039    0.000    0.039    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f41940}
        1    0.030    0.030    0.737    0.737 /Users/tgg/Github/atr_igor/testkeras2_savedmodel.py:80(profile_pytorch_inference)
     3600    0.029    0.000    0.029    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4edf740}
25200/500    0.022    0.000    0.699    0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1534(_call_impl)
    50400    0.020    0.000    0.020    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1696(__getattr__)
     9600    0.020    0.000    0.020    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4edda80}
     1800    0.017    0.000    0.411    0.000 /Users/tgg/Github/atr_igor/transformer.py:66(forward)
25200/500    0.014    0.000    0.698    0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:241(forward)
     3200    0.013    0.000    0.013    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f23600}
     3200    0.012    0.000    0.109    0.000 /Users/tgg/Github/atr_igor/transformer.py:47(forward)
     9700    0.011    0.000    0.297    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/linear.py:115(forward)
25200/500    0.011    0.000    0.699    0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1528(_wrapped_call_impl)
   149800    0.010    0.000    0.010    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:212(is_tracing_enabled)
     1800    0.009    0.000    0.121    0.000 /Users/tgg/Github/atr_igor/transformer.py:77(attention)
     9000    0.009    0.000    0.009    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f432e0}
     4200    0.008    0.000    0.014    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/functional.py:1279(dropout)
     7200    0.008    0.000    0.008    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f584a0}
     5000    0.008    0.000    0.008    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4ef0e00}
     3000    0.007    0.000    0.664    0.000 /Users/tgg/Github/atr_igor/transformer.py:26(forward)
     3400    0.006    0.000    0.006    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4edf880}
     1800    0.005    0.000    0.005    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f22fc0}
     1800    0.005    0.000    0.005    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4ede340}

****************************************************************************************************
Keras Inference Time: 2.2881689071655273
         1447802 function calls (1311645 primitive calls) in 2.286 seconds

   Ordered by: internal time
   List reduced from 1591 to 25 due to restriction <25>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      100    1.671    0.017    1.671    0.017 {built-in method tensorflow.python._pywrap_tfe.TFE_Py_Execute}
      400    0.128    0.000    0.129    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/constant_op.py:70(convert_to_eager_tensor)
      667    0.030    0.000    0.030    0.000 {built-in method tensorflow.python.client._pywrap_tf_session.TF_FinishOperation}
      168    0.029    0.000    0.029    0.000 {method '_numpy_internal' of 'tensorflow.python.framework.ops.EagerTensor' objects}
10721/155    0.016    0.000    0.096    0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/ast.py:488(generic_visit)
103263/30224    0.015    0.000    0.033    0.000 {built-in method builtins.hash}
178603/178541    0.012    0.000    0.032    0.000 {built-in method builtins.isinstance}
   113900    0.010    0.000    0.011    0.000 {built-in method builtins.getattr}
 5098/199    0.009    0.000    0.021    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/autograph/pyct/ast_util.py:33(copy)
    34870    0.008    0.000    0.024    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/tensor.py:894(__hash__)
  6316/44    0.008    0.000    0.090    0.002 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/autograph/pyct/transformer.py:417(visit)
 16072/82    0.008    0.000    0.123    0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/ast.py:414(visit)
    59573    0.007    0.000    0.007    0.000 {built-in method builtins.hasattr}
    26280    0.007    0.000    0.012    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/autograph/pyct/anno.py:130(hasanno)
    46432    0.007    0.000    0.009    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/ast.py:255(iter_fields)
     1098    0.004    0.000    0.006    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/typing.py:1911(_get_protocol_attrs)
    24987    0.004    0.000    0.005    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/enum.py:1230(__hash__)
    34870    0.004    0.000    0.005    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/tensor_shape.py:1508(__hash__)
      667    0.004    0.000    0.043    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/ops.py:959(_create_c_op)
      231    0.004    0.000    0.004    0.000 {built-in method tensorflow.python.client._pywrap_tf_session.TF_GraphCopyFunction}
     3502    0.004    0.000    0.009    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:285(decorator)
      797    0.003    0.000    0.004    0.000 {built-in method tensorflow.python.client._pywrap_tf_session.TF_OperationGetAttrValueProto}
  2030/98    0.003    0.000    0.013    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/gast/astn.py:17(generic_visit)
      200    0.003    0.000    0.006    0.000 /Users/tgg/Github/atr_igor/data.py:237(pad_pack)
     7339    0.003    0.000    0.003    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/dtypes.py:264(__eq__)

****************************************************************************************************
Exception ignored in: <function AtomicFunction.__del__ at 0x16c817880>
Traceback (most recent call last):
  File "/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/atomic_function.py", line 291, in __del__
TypeError: 'NoneType' object is not subscriptable

johndpope commented 1 month ago

It probably deserves a mention in the README that performance suffers until the model is exported as a SavedModel.