jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Provide better error message when using device_put_sharded inside jitted code #6221

Open GeorgOstrovski opened 3 years ago

GeorgOstrovski commented 3 years ago

Calling device_put_sharded inside jitted code currently results in

TypeError: No canonicalize_dtype handler for type: <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>

Regular device_put, by contrast, is effectively a no-op inside jitted code but does not fail. The error message for device_put_sharded could be more helpful here.
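One way to get a clearer failure is to check for tracers before attempting any device transfer. The sketch below is a hypothetical user-side wrapper. checked_device_put_sharded is not part of JAX, and this is not how JAX itself handles the case. It assumes that values seen under jax.jit are instances of jax.core.Tracer, which matches the DynamicJaxprTracer type named in the error above.

```python
import jax
import jax.numpy as jnp


def checked_device_put_sharded(shards, devices):
    """Hypothetical wrapper around jax.device_put_sharded.

    Raises a descriptive TypeError when any shard is a tracer, which is
    what device_put_sharded receives when it is called inside jax.jit.
    """
    for shard in shards:
        if isinstance(shard, jax.core.Tracer):
            raise TypeError(
                "device_put_sharded was called inside a jit-transformed "
                f"function (got {type(shard).__name__}). It needs concrete "
                "arrays: call it outside jax.jit and pass the result in."
            )
    return jax.device_put_sharded(shards, devices)


@jax.jit
def inside_jit():
    devices = jax.local_devices()
    # Under jit, jnp.ones produces a tracer rather than a concrete array.
    shards = [jnp.ones(5) for _ in devices]
    return checked_device_put_sharded(shards, devices)
```

Calling inside_jit() then fails at trace time with the descriptive TypeError instead of the canonicalize_dtype message. Calling the wrapper outside jit behaves like jax.device_put_sharded.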

selamw1 commented 9 months ago

Hi @GeorgOstrovski

To illustrate the use of device_put_sharded within JIT-compiled code, I created the following example:

```python
import jax
import jax.numpy as jnp

@jax.jit
def sharded_computation():
  devices = jax.local_devices()
  # One shard per local device.
  x = [jnp.ones(5) for _ in devices]
  # device_put_sharded receives traced values here, which triggers the error.
  y = jax.device_put_sharded(x, devices)
  # jnp.allclose (not np.allclose) keeps the comparison traceable under jit.
  return jnp.allclose(y, jnp.stack(x))

result = sharded_computation()
print(result)
```

When executed with JAX version 0.4.23 on both CPU and GPU, it produced the following error message:

XlaRuntimeError: INVALID_ARGUMENT: Not supported: The C++ jax jit execution path, only accepts DeviceArray, Numpy arrays scalars of supported types (see implementation), or Python scalars. Got type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>

However, running the same code with JAX version 0.3.25 on TPU still produced the error message you originally reported:

TypeError: No canonicalize_dtype handler for type: <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>

For further details, please refer to the provided gists for CPU/GPU and TPU executions.

selamw1 commented 3 weeks ago

The latest JAX version, 0.4.33, produces the same error message on CPU, GPU, and TPU:

XlaRuntimeError: INVALID_ARGUMENT: Not supported: The C++ jax jit execution path, only accepts DeviceArray, Numpy arrays scalars of supported types (see implementation), or Python scalars. Got type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
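Until the message improves, a workaround consistent with this thread is to call device_put_sharded outside jit, on concrete arrays, and pass its result into the jitted function. The following is a minimal sketch that assumes a single host. row_sums is a hypothetical name, and I have only reasoned through the single-device case.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()
shards = [jnp.ones(5) for _ in devices]

# Called outside jit: the shards are concrete arrays, so this succeeds and
# returns an array with a leading axis of size len(devices).
stacked = jax.device_put_sharded(shards, devices)


@jax.jit
def row_sums(y):
    # Only traced array operations run inside jit.
    return y.sum(axis=1)


sums = row_sums(stacked)
print(sums)
```

With this split, device placement happens eagerly on the host side, and the jitted function only ever receives the already-placed array as an ordinary argument.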