phate09 opened this issue 3 weeks ago
This seems to fail here because `params` contains `list`s as values:

```
{'devices': [None], 'srcs': [None], 'copy_semantics': [<CopySemantics.ALIAS: 1>]}
```
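Here is a minimal, JAX-free illustration of why list values produce this error. It assumes the params dict is hashed somewhere, for example as part of a cache key; the exact hashing site inside JAX is not shown here. The `CopySemantics` member is replaced by a plain string, and `try_hash` and `freeze` are hypothetical helpers:

```python
# Hypothetical stand-in for the params above; the CopySemantics enum member
# is replaced by a string so the snippet runs without JAX.
params = {'devices': [None], 'srcs': [None], 'copy_semantics': ['ALIAS']}

def try_hash(p):
    """Hash the dict's items the way a cache key might; return the error text on failure."""
    try:
        return hash(tuple(sorted(p.items())))
    except TypeError as e:
        return str(e)

def freeze(p):
    """Convert list values to tuples so every value in the mapping is hashable."""
    return {k: tuple(v) if isinstance(v, list) else v for k, v in p.items()}

print(try_hash(params))                            # error text mentioning the unhashable list
print(isinstance(try_hash(freeze(params)), int))   # True: tuples hash fine
```

Storing sequence-valued params as tuples rather than lists is the usual way to keep such a mapping hashable.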
@yashk2810 does this look right?
The general advice here is to use `jax.jit` + `shard_map` instead. `pmap` is in maintenance mode, so I expect things like this to break. I can fix it for this instance, but I highly recommend transitioning to the better-supported APIs:
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
https://jax.readthedocs.io/en/latest/notebooks/shard_map.html
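Here is a minimal sketch of the suggested `jax.jit` + `shard_map` pattern. It is not taken from the linked docs: `normalize`, `per_shard` and the mesh axis name `'i'` are hypothetical. The import fallback assumes `shard_map` is exposed as `jax.shard_map` in newer releases and as `jax.experimental.shard_map.shard_map` in older ones such as 0.4.33:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # assumption: newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases, e.g. jax 0.4.33 from this issue
    from jax.experimental.shard_map import shard_map

# 1-D mesh over every visible device; the axis name 'i' is arbitrary.
mesh = Mesh(np.array(jax.devices()), axis_names=('i',))

def per_shard(x):
    # Body that would previously have been wrapped in jax.pmap:
    # it only sees this device's block of the global input.
    local_sum = jnp.sum(x)
    # Collective over the mesh axis, analogous to lax.psum inside pmap.
    total = jax.lax.psum(local_sum, 'i')
    return x / total

@jax.jit
def normalize(x):
    # in_specs/out_specs split and reassemble the leading dimension across 'i'.
    return shard_map(per_shard, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(x)

n = len(jax.devices())
x = jnp.arange(1, 4 * n + 1, dtype=jnp.float32)  # length divisible by device count
y = normalize(x)
print(y.sum())  # close to 1.0
```

Unlike `pmap`, the input is a single global array whose leading dimension is split across devices by `in_specs`, so the same code runs unchanged on one CPU device or on many accelerators.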
In this case, this works:
```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def create_from_numpy_arr(arr):
    # Builds a constant array from NumPy data; `arr` is unused in this repro.
    return jnp.asarray(np.array([[0, 0, 0]]))

arr = jnp.array([1])
with jax.disable_jit():
    # With jit disabled, the function runs eagerly and does not raise.
    create_from_numpy_arr(arr)
```
Description
Probably linked to #3401. Here is a reproducible example:
You will get the error:

```
TypeError: unhashable type: 'list'
```
System info (python version, jaxlib version, accelerator, etc.)
Python 3.11.0, jax 0.4.33