jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

pmap fails under disable_jit() (TypeError: unhashable type: 'list') #24642

Open phate09 opened 3 weeks ago

phate09 commented 3 weeks ago

Description

This is probably related to #3401. Here is a reproducible example:

import jax
import jax.numpy as jnp
import numpy as np
import chex

@jax.pmap
def create_from_numpy_arr(arr) -> chex.Array:
    return jnp.asarray(np.array([[0, 0, 0]]))

arr = jnp.array([1])
with jax.disable_jit():
    create_from_numpy_arr(arr)

Running this raises TypeError: unhashable type: 'list'.

System info (python version, jaxlib version, accelerator, etc.)

Python 3.11.0, jax 0.4.33

superbobry commented 3 weeks ago

This seems to fail here:

https://github.com/jax-ml/jax/blob/8abedda8a62d0e3bd8babcc0c588251012bd5af7/jax/_src/interpreters/pxla.py#L458-L459

because params contains lists as values:

{'devices': [None], 'srcs': [None], 'copy_semantics': [<CopySemantics.ALIAS: 1>]}

@yashk2810 does this look right?
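To illustrate the failure mode, here is a minimal, self-contained Python sketch. It assumes that the linked lines end up hashing the params (for example, to build a cache key). `params_key`, `try_hash`, and `freeze` are hypothetical helpers, and `CopySemantics` is a local stand-in for the enum shown above, not JAX's own class. Hashing a tuple that contains a list raises the same TypeError, while converting the list values to tuples makes the key hashable.

```python
import enum
from typing import Optional


class CopySemantics(enum.Enum):
    # Local stand-in for the enum in the reported params (not JAX's class).
    ALIAS = 1


# Same shape as the params reported above: every value is a list.
params = {"devices": [None], "srcs": [None], "copy_semantics": [CopySemantics.ALIAS]}


def params_key(p: dict) -> tuple:
    # Hypothetical cache key: sorted (name, value) pairs.
    return tuple(sorted(p.items()))


def try_hash(p: dict) -> Optional[str]:
    # Returns None if the key hashes, otherwise the TypeError message.
    try:
        hash(params_key(p))
        return None
    except TypeError as e:
        return str(e)


def freeze(p: dict) -> dict:
    # Convert list values to tuples so that every value is hashable.
    return {k: tuple(v) if isinstance(v, list) else v for k, v in p.items()}


print(try_hash(params))          # unhashable type: 'list'
print(try_hash(freeze(params)))  # None
```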

yashk2810 commented 3 weeks ago

The general advice is to use jax.jit together with shard_map instead. pmap is in maintenance mode, so I expect breakages like this one. I can fix this particular case, but I highly recommend transitioning to the better-supported APIs:

https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html

https://jax.readthedocs.io/en/latest/notebooks/shard_map.html
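As a rough illustration of that recommendation, here is one way the repro's function could be expressed with jax.jit + shard_map instead of pmap. This is a sketch under assumptions, not an official migration recipe: it assumes shard_map is importable either as jax.shard_map (newer releases) or from jax.experimental.shard_map (older releases such as 0.4.x), and it uses a one-axis mesh named "i" over all available devices.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases, e.g. 0.4.x

# One-dimensional mesh over all devices, with a single axis named "i".
mesh = Mesh(np.array(jax.devices()), axis_names=("i",))


def create_from_numpy_arr(arr):
    # Runs once per device on that device's shard of `arr`.
    return jnp.asarray(np.array([[0, 0, 0]]))


# in_specs/out_specs take over the role of pmap's mapped leading axis.
f = jax.jit(
    shard_map(create_from_numpy_arr, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
)

arr = jnp.arange(jax.device_count())
out = f(arr)
print(out.shape)  # (number of devices, 3)
```

Unlike pmap, which stacks per-device results along a new leading axis, shard_map with out_specs=P("i") concatenates the per-device (1, 3) blocks along axis 0, so the result has shape (device_count, 3) rather than (device_count, 1, 3).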

yashk2810 commented 3 weeks ago

In this case, replacing jax.pmap with jax.jit works:

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def create_from_numpy_arr(arr):
    return jnp.asarray(np.array([[0, 0, 0]]))

arr = jnp.array([1])
with jax.disable_jit():
    create_from_numpy_arr(arr)  # runs without error
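For context on why the jax.jit version runs: inside jax.disable_jit(), functions decorated with jax.jit are executed eagerly, op by op, instead of being traced and compiled. The hypothetical function keep_if_positive below makes this visible with a Python if on an array value, which needs a concrete value: it runs under disable_jit but raises a concretization error when traced under jit.

```python
import jax
import jax.numpy as jnp


@jax.jit
def keep_if_positive(x):
    # A Python `if` on an array needs a concrete value, so this only works
    # when the function is not being traced.
    if x > 0:
        return x
    return jnp.zeros_like(x)


with jax.disable_jit():
    eager = keep_if_positive(jnp.array(3.0))  # runs eagerly: x is concrete
print(eager)

try:
    keep_if_positive(jnp.array(3.0))  # traced under jit: x is abstract
    traced_ok = True
except jax.errors.ConcretizationTypeError:
    traced_ok = False
print(traced_ok)  # False
```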