google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

How to pass data carried by a class through forward pass #1719

Closed: aycandv closed this issue 2 years ago

aycandv commented 2 years ago

Description

When performing linear regression, we normally pass data through the forward pass as JAX arrays. I wonder whether it is possible to pass data as a custom object (i.e., an instance of a user-defined class). When I try to run optimization using a custom data-carrier class, Trax raises an exception:

TypeError: Value 'ShapeDtype{shape:(50,), dtype:object}' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

Is there a way to make Trax recognize my custom object as a JAX-compatible object?
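In plain JAX, a user-defined class can be made JAX-compatible by registering it as a pytree. JAX then treats the class's array fields as leaves, so `jax.grad` and `jax.jit` can trace through instances of it. The sketch below is a minimal illustration for a linear-regression loss, assuming the class holds only numeric arrays. `RegressionBatch`, `mse_loss`, `sgd_step` and `LEARNING_RATE` are hypothetical names. It does not show that Trax 1.4.1 layers accept arbitrary pytrees as inputs (the error log in this issue shows Trax converting each layer input into a `ShapeDtype`), so treat it as a starting point rather than a confirmed Trax fix.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

LEARNING_RATE = 0.1  # hypothetical step size for this sketch


@register_pytree_node_class
class RegressionBatch:
    """Hypothetical data-carrier class; JAX sees its arrays as pytree leaves."""

    def __init__(self, x, y):
        self.x = x  # features, shape (n,)
        self.y = y  # targets, shape (n,)

    def tree_flatten(self):
        # Children are the array leaves JAX traces; aux_data (None) is static.
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the object from (possibly traced) leaves.
        return cls(*children)


def mse_loss(params, batch):
    """Mean squared error of the linear model w * x + b on one batch."""
    w, b = params
    pred = w * batch.x + b
    return jnp.mean((pred - batch.y) ** 2)


@jax.jit
def sgd_step(params, batch):
    """One gradient-descent step; `batch` passes straight through jit and grad."""
    grads = jax.grad(mse_loss)(params, batch)
    return tuple(p - LEARNING_RATE * g for p, g in zip(params, grads))


# Usage: fit y = 3x + 1 on 50 points.
x = jnp.linspace(0.0, 1.0, 50)
batch = RegressionBatch(x, 3.0 * x + 1.0)
params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(1000):
    params = sgd_step(params, batch)
print(params)  # approximately (3.0, 1.0)
```

A field that is not numeric (for example a NumPy array of Python objects) would still fail, because under `jax.jit` every pytree leaf must be convertible to a numeric JAX array.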

You can reach the code that generates the exception, and experiment with it here: https://colab.research.google.com/drive/18_1lN7k5psv8kNdYqeIaToqblKEbWhKb?usp=sharing

Environment information

OS: Ubuntu (Google Colab)

$ pip freeze | grep trax
trax==1.4.1

$ pip freeze | grep tensor
tensorboard==2.7.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow @ file:///tensorflow-2.7.0-cp37-cp37m-linux_x86_64.whl
tensorflow-datasets==4.0.1
tensorflow-estimator==2.7.0
tensorflow-gcs-config==2.7.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.22.0
tensorflow-metadata==1.4.0
tensorflow-probability==0.15.0
tensorflow-text==2.7.3

$ pip freeze | grep jax
jax==0.2.25
jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.1.74+cuda11.cudnn805-cp37-none-manylinux2010_x86_64.whl

$ python -V
Python 3.7.12

For bugs: reproduction and error logs

Steps to reproduce:

See the Colab notebook linked in the description above.

# Error logs:
LayerError: Exception passing through layer Serial (in _forward_abstract):
  layer created in file [...]/<ipython-input-12-69f735998ff5>, line 1
  layer input shapes: (ShapeDtype{shape:(50,), dtype:object}, ShapeDtype{shape:(50,), dtype:object}, ShapeDtype{shape:(50,), dtype:float32})

  File [...]/jax/_src/util.py, line 43, in safe_map
    return list(map(f, *args))

  File [...]/jax/_src/api_util.py, line 315, in shaped_abstractify
    return core.ShapedArray(np.shape(x), _dtype(x), weak_type=weak_type,

  File [...]/jax/_src/api_util.py, line 303, in _dtype
    return dtypes.result_type(x)

  File [...]/jax/_src/dtypes.py, line 369, in result_type
    return canonicalize_dtype(_lattice_result_type(*args)[0])

  File [...]/jax/_src/dtypes.py, line 351, in _lattice_result_type
    dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))

  File [...]/jax/_src/dtypes.py, line 351, in <genexpr>
    dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))

  File [...]/jax/_src/dtypes.py, line 246, in _dtype_and_weaktype
    return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)

  File [...]/jax/_src/dtypes.py, line 346, in dtype
    raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "

TypeError: Value 'ShapeDtype{shape:(50,), dtype:object}' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
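The `layer input shapes` line suggests that the first two inputs are NumPy arrays holding 50 Python objects each (`dtype:object`), which JAX cannot represent. One workaround, sketched here under that assumption, is to unpack the objects into plain numeric arrays before they reach the model. `Sample`, `to_numeric_arrays` and `jax_accepts` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass

import jax.numpy as jnp
import numpy as np


@dataclass
class Sample:
    """Hypothetical per-example record standing in for the custom data class."""
    x: float
    y: float


def to_numeric_arrays(samples):
    """Unpack Sample objects into float32 feature and target arrays."""
    xs = np.array([s.x for s in samples], dtype=np.float32)
    ys = np.array([s.y for s in samples], dtype=np.float32)
    return xs, ys


def jax_accepts(array):
    """Return True if JAX can convert `array` into a JAX array."""
    try:
        jnp.asarray(array)
        return True
    except (TypeError, ValueError):
        return False


samples = [Sample(x=float(i), y=2.0 * i + 1.0) for i in range(50)]

# Wrapping the objects in a NumPy array gives shape (50,) and dtype=object,
# matching the ShapeDtype{shape:(50,), dtype:object} in the error log.
object_batch = np.array(samples, dtype=object)
print(object_batch.shape, object_batch.dtype)  # (50,) object
print(jax_accepts(object_batch))               # False

xs, ys = to_numeric_arrays(samples)
print(xs.shape, xs.dtype)                      # (50,) float32
print(jax_accepts(xs))                         # True
```

Only the numeric arrays `xs` and `ys` would then be fed to the model. The class can still be used to build and organise the data outside the traced computation.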