value = 0.9189385332046727, dtype = <class 'numpy.float32'>
def _default_convert_to_tensor(value, dtype=None):
"""Default tensor conversion function for array, bool, int, float, and complex."""
if JAX_MODE:
# TODO(b/223267515): We shouldn't need to specialize here.
if hasattr(value, 'dtype') and jax.dtypes.issubdtype(
> value.dtype, jax.dtypes.prng_key
):
E AttributeError: module 'jax.dtypes' has no attribute 'prng_key'
I am receiving the following: