When performing linear regression, we normally pass data to the forward pass as a JAX array. I wonder whether it is possible to pass the data as a custom object instead (e.g. an instance of a user-defined class). When I run the optimization with a custom data-carrier class, Trax raises this exception:
TypeError: Value 'ShapeDtype{shape:(50,), dtype:object}' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
Is there a way to make Trax recognize my custom object as a JAX-compatible type?
# Error logs:
```
LayerError: Exception passing through layer Serial (in _forward_abstract):
  layer created in file [...]/<ipython-input-12-69f735998ff5>, line 1
  layer input shapes: (ShapeDtype{shape:(50,), dtype:object}, ShapeDtype{shape:(50,), dtype:object}, ShapeDtype{shape:(50,), dtype:float32})

  File [...]/jax/_src/util.py, line 43, in safe_map
    return list(map(f, *args))

  File [...]/jax/_src/api_util.py, line 315, in shaped_abstractify
    return core.ShapedArray(np.shape(x), _dtype(x), weak_type=weak_type,

  File [...]/jax/_src/api_util.py, line 303, in _dtype
    return dtypes.result_type(x)

  File [...]/jax/_src/dtypes.py, line 369, in result_type
    return canonicalize_dtype(_lattice_result_type(*args)[0])

  File [...]/jax/_src/dtypes.py, line 351, in _lattice_result_type
    dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))

  File [...]/jax/_src/dtypes.py, line 351, in <genexpr>
    dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))

  File [...]/jax/_src/dtypes.py, line 246, in _dtype_and_weaktype
    return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)

  File [...]/jax/_src/dtypes.py, line 346, in dtype
    raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "

TypeError: Value 'ShapeDtype{shape:(50,), dtype:object}' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
```
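The traceback suggests that JAX tries to turn each layer input into a `ShapedArray`, which requires a numeric dtype, so an array of arbitrary Python objects cannot get through. For reference, plain JAX does accept custom containers when they are registered as pytrees whose leaves are numeric arrays. The sketch below shows this with `jax.tree_util.register_pytree_node_class`; `DataBatch` and `mse_loss` are hypothetical names, not taken from my notebook. It runs with `jax.jit` and `jax.value_and_grad`, but I have not checked whether Trax layers (which compute a `ShapeDtype` for every input, as the log shows) accept such an object directly.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class DataBatch:
    """Hypothetical data-carrier: holds numeric arrays, not Python objects."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        # children: the array leaves JAX traces; aux_data: static metadata (none here)
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the object from its (possibly traced) leaves
        return cls(*children)


@jax.jit
def mse_loss(params, batch):
    # Linear model y_hat = w * x + b, mean squared error against batch.y
    w, b = params
    preds = w * batch.x + b
    return jnp.mean((preds - batch.y) ** 2)


x = jnp.linspace(0.0, 1.0, 50)
batch = DataBatch(x, 3.0 * x + 1.0)

# The custom object passes through jit and grad because JAX flattens it to arrays
loss, grads = jax.value_and_grad(mse_loss)((0.0, 0.0), batch)
print(loss, grads)
```

The key design point is that the class itself never reaches JAX's dtype checks: JAX only sees the flattened leaves (`x` and `y`), which are ordinary float arrays.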
Steps to reproduce:
The code that raises the exception is in this Colab notebook, where you can also experiment with it: https://colab.research.google.com/drive/18_1lN7k5psv8kNdYqeIaToqblKEbWhKb?usp=sharing
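For readers who cannot open the notebook, here is a minimal JAX-only sketch of what appears to be the same failure. It assumes the custom objects end up packed into a NumPy array of `dtype=object`, which is what `ShapeDtype{shape:(50,), dtype:object}` in the log suggests. `DataPoint` and `predict` are hypothetical stand-ins, not the notebook's code, and the exact error wording may differ between JAX versions, although it should be a `TypeError` either way.

```python
import numpy as np
import jax


class DataPoint:
    """Hypothetical data-carrier class standing in for the one in the notebook."""

    def __init__(self, x):
        self.x = x


# Packing 50 instances into NumPy yields shape (50,) with dtype=object,
# matching the layer input shapes reported in the log.
batch = np.array([DataPoint(float(i)) for i in range(50)], dtype=object)


def predict(w, b, inputs):
    # Simple linear model; never actually traced because the input is rejected
    return w * inputs + b


def reproduce():
    try:
        jax.jit(predict)(2.0, 1.0, batch)
    except TypeError as e:
        return e
    return None


err = reproduce()
print(f"{type(err).__name__}: {err}")
```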