When I set TF (2.11.0) as the computation backend and try to use SelfAttention from the research module, I get the exception below:
File "/usr/local/lib/python3.10/dist-packages/trax/layers/research/efficient_attention.py", line 1536, in
lambda x: np.issubdtype(x.dtype, np.inexact), inputs)
File "/usr/local/lib/python3.10/dist-packages/numpy/core/numerictypes.py", line 416, in issubdtype
arg1 = dtype(arg1).type
TypeError: Cannot interpret 'tf.float32' as a data type
Probably the problem is that a TF dtype is not a type NumPy can interpret, so np.issubdtype rejects it. The same error occurs outside trax; see the reproduction below.
# Steps to reproduce:
jax.np.issubdtype(tf.float64, np.floating)
Gives:
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3442, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-19-b62b80ea3d79>", line 1, in <module>
jax.np.issubdtype(tf.float64, np.floating)
File "/usr/local/lib/python3.10/dist-packages/numpy/core/numerictypes.py", line 416, in issubdtype
arg1 = dtype(arg1).type
TypeError: Cannot interpret 'tf.float64' as a data type
In efficient_attention.py the error comes from the check below. With the TF backend, x.dtype is a TF dtype (e.g. tf.float32), which NumPy's issubdtype cannot interpret:
inputs_is_differentiable = fastmath.nested_map(
lambda x: np.issubdtype(x.dtype, np.inexact), inputs)
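A minimal sketch of a backend-agnostic version of this check, assuming the dtypes reaching the lambda are either NumPy/JAX dtypes or TensorFlow tf.DType objects. It relies on tf.DType exposing `as_numpy_dtype` (e.g. `tf.float32.as_numpy_dtype` is `np.float32`). The helper name `is_inexact_dtype` is hypothetical, and this is not the actual trax fix.

```python
import numpy as np


def is_inexact_dtype(dtype) -> bool:
    """Hypothetical helper: True for floating-point or complex dtypes.

    Assumes TensorFlow dtypes (tf.DType) expose `as_numpy_dtype`, mapping
    e.g. tf.float32 -> np.float32. NumPy/JAX dtypes have no such attribute
    and are passed to np.issubdtype unchanged.
    """
    # Unwrap a TF dtype into its NumPy scalar type; leave other dtypes alone.
    numpy_dtype = getattr(dtype, "as_numpy_dtype", dtype)
    return bool(np.issubdtype(numpy_dtype, np.inexact))


if __name__ == "__main__":
    print(is_inexact_dtype(np.dtype("float32")))  # True
    print(is_inexact_dtype(np.dtype("int32")))    # False
    try:
        import tensorflow as tf  # optional: only needed for the TF case
    except ImportError:
        tf = None
    if tf is not None:
        print(is_inexact_dtype(tf.float32))  # True
        print(is_inexact_dtype(tf.int32))    # False
```

With such a helper, the lambda in the same fastmath.nested_map call could become `lambda x: is_inexact_dtype(x.dtype)`. Duck-typing on `as_numpy_dtype` avoids importing TensorFlow when the JAX backend is in use.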
Environment information: TF 2.11.0 backend, Python 3.10 (per the dist-packages paths in the tracebacks).
Error logs: see the tracebacks above.