google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

SelfAttention - problem with tensorflow 2.11.0 #1771

Open mmarcinmichal opened 1 year ago

mmarcinmichal commented 1 year ago

Description

When I set the TF backend (2.11.0) for computation and try to use SelfAttention from the research module, I get the exception below:

  File "/usr/local/lib/python3.10/dist-packages/trax/layers/research/efficient_attention.py", line 1536, in <lambda>
    lambda x: np.issubdtype(x.dtype, np.inexact), inputs)
  File "/usr/local/lib/python3.10/dist-packages/numpy/core/numerictypes.py", line 416, in issubdtype
    arg1 = dtype(arg1).type
TypeError: Cannot interpret 'tf.float32' as a data type

The problem is probably that the TF data type is not a type NumPy recognizes as a subtype. In short:

jax.np.issubdtype(tf.float64, np.floating)

Gives:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3442, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-19-b62b80ea3d79>", line 1, in <module>
    jax.np.issubdtype(tf.float64, np.floating)
  File "/usr/local/lib/python3.10/dist-packages/numpy/core/numerictypes.py", line 416, in issubdtype
    arg1 = dtype(arg1).type
TypeError: Cannot interpret 'tf.float64' as a data type

Environment information

OS: Windows and WSL (Ubuntu 20.04.5 LTS)

$ pip freeze | grep trax
trax==1.4.1

$ pip freeze | grep tensor
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-datasets==4.8.1
tensorflow-estimator==2.11.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.29.0
tensorflow-metadata==1.12.0
tensorflow-text==2.11.0

$ pip freeze | grep jax
jax==0.4.1
jaxlib==0.4.1+cuda11.cudnn86

$ python -V
Python 3.10.9

For bugs: reproduction and error logs

# Steps to reproduce:
jax.np.issubdtype(tf.float64, np.floating)

Gives:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3442, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-19-b62b80ea3d79>", line 1, in <module>
    jax.np.issubdtype(tf.float64, np.floating)
  File "/usr/local/lib/python3.10/dist-packages/numpy/core/numerictypes.py", line 416, in issubdtype
    arg1 = dtype(arg1).type
TypeError: Cannot interpret 'tf.float64' as a data type

This happens in the efficient_attention.py module during the following computation:

inputs_is_differentiable = fastmath.nested_map(
          lambda x: np.issubdtype(x.dtype, np.inexact), inputs)

Error logs:

See the traceback above.
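
A possible workaround, as a minimal sketch only: the check in efficient_attention.py could convert the dtype to a NumPy type before calling np.issubdtype. This relies on TensorFlow dtypes exposing an as_numpy_dtype attribute. The helper name is_inexact and the FakeTFDType stand-in class are hypothetical and not part of Trax or TensorFlow; the stand-in is only there so the sketch runs without TensorFlow installed, and I have not tested this against Trax itself.

```python
import numpy as np


def is_inexact(dtype):
    """Return True if `dtype` is a floating or complex type.

    Hypothetical helper, not part of Trax. TensorFlow dtypes expose an
    `as_numpy_dtype` attribute; anything else is passed to NumPy as-is.
    """
    np_dtype = getattr(dtype, "as_numpy_dtype", dtype)
    return np.issubdtype(np_dtype, np.inexact)


class FakeTFDType:
    """Stand-in for tf.DType so this sketch runs without TensorFlow."""

    def __init__(self, np_type):
        self.as_numpy_dtype = np_type


fake_float32 = FakeTFDType(np.float32)
fake_int32 = FakeTFDType(np.int32)

# Passing the raw object to NumPy fails the same way as in the report:
# NumPy cannot interpret an unknown dtype object.
try:
    np.issubdtype(fake_float32, np.inexact)
    raw_check_failed = False
except TypeError:
    raw_check_failed = True

print(raw_check_failed)
print(is_inexact(fake_float32))
print(is_inexact(fake_int32))
print(is_inexact(np.dtype("float64")))
```

With a real TF backend, the same idea would mean calling the helper on x.dtype inside the nested_map lambda instead of calling np.issubdtype directly. Whether that is the right fix for Trax is an open question for the maintainers.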