GeorgOstrovski opened this issue 3 years ago
Hi @GeorgOstrovski
To illustrate the use of `device_put_sharded` within JIT-compiled code, I created the following example:

```python
import jax
import jax.numpy as jnp

@jax.jit
def sharded_computation():
    devices = jax.local_devices()
    # One shard per local device.
    x = [jnp.ones(5) for _ in devices]
    # Inside jit the shards are tracers; this is the call that fails.
    y = jax.device_put_sharded(x, devices)
    # jnp (not np) so the comparison itself would also work on tracers.
    return jnp.allclose(y, jnp.stack(x))

result = sharded_computation()
print(result)
```

When executed with JAX 0.4.23 on both CPU and GPU, it produced the following error message:
```
XlaRuntimeError: INVALID_ARGUMENT: Not supported: The C++ jax jit execution path, only accepts DeviceArray, Numpy arrays scalars of supported types (see implementation), or Python scalars. Got type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
```
When the same code ran with JAX 0.3.25 on TPU, however, it still produced the error message you originally reported:

```
TypeError: No canonicalize_dtype handler for type: <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
```
For further details, please refer to the gists provided for the CPU/GPU and TPU runs.

The latest JAX release, 0.4.33, produces the same error message on all three accelerators (CPU, GPU, and TPU):

```
XlaRuntimeError: INVALID_ARGUMENT: Not supported: The C++ jax jit execution path, only accepts DeviceArray, Numpy arrays scalars of supported types (see implementation), or Python scalars. Got type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
```
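If the goal is only to get a stacked array laid out with one row per device from inside a jitted function, one possible workaround is to express the layout as a sharding constraint rather than calling `device_put_sharded`. The sketch below swaps in a different API than the one discussed in this issue. It builds a 1-D `Mesh` over the local devices (the axis name `"d"` is an arbitrary assumption), stacks the shards with `jnp.stack`, and requests a `NamedSharding` through `jax.lax.with_sharding_constraint`, which accepts tracers. It is not a drop-in replacement: the result may carry a different sharding type than the one `device_put_sharded` returns.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.local_devices()
# Assumption: a 1-D mesh over the local devices; the axis name "d" is arbitrary.
mesh = Mesh(np.array(devices), ("d",))
# Split the leading (stacked) axis across the mesh so that row i lands on
# devices[i], similar to how device_put_sharded places shard i.
sharding = NamedSharding(mesh, P("d"))

@jax.jit
def sharded_computation():
    x = [jnp.ones(5) for _ in devices]
    stacked = jnp.stack(x)  # shape (len(devices), 5); still a tracer here
    # Ask the compiler for the layout instead of moving data with
    # device_put_sharded, which cannot consume tracers.
    y = jax.lax.with_sharding_constraint(stacked, sharding)
    return y, jnp.allclose(y, stacked)

y, ok = sharded_computation()
print(bool(ok))  # True
```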
Calling `device_put_sharded` inside jitted code currently results in

```
TypeError: No canonicalize_dtype handler for type: <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
```

as opposed to regular `device_put`, which doesn't do anything in jitted code but doesn't fail either. The error message could be more helpful here.
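The contrast between the two functions can be reproduced with a small sketch. It jits two functions over the same input. The first calls `jax.device_put` with no target device, which simply traces through. The second calls `jax.device_put_sharded` on a single shard. The reported exception type differed across the JAX versions mentioned in this thread, so the sketch catches `Exception` broadly rather than a specific class. It also assumes at least one local device.

```python
import jax
import jax.numpy as jnp

@jax.jit
def with_device_put(x):
    # With no target device, device_put under jit passes the tracer through.
    return jax.device_put(x) + 1

@jax.jit
def with_device_put_sharded(x):
    # x is a tracer here; device_put_sharded expects concrete arrays.
    return jax.device_put_sharded([x], jax.local_devices()[:1])

out = with_device_put(jnp.ones(5))
print("device_put under jit:", out)

try:
    with_device_put_sharded(jnp.ones(5))
    print("device_put_sharded under jit: no error")
except Exception as err:  # TypeError or XlaRuntimeError, depending on version
    print("device_put_sharded under jit:", type(err).__name__)
```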