Closed richardsliu closed 1 month ago
@FanhaiLu1 I'm still seeing the same error after the latest fix. Can you take a look?
return func(*args, **kwargs)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 2656, in get
values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 871, in get_objects
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::PyTorchRayWorker.prefill_ray() (pid=904, ip=10.36.8.6, actor_id=7744b1669cd8fdc6809a72a502000000, repr=<jetstream_pt.ray_worker.PyTorchRayWorker object at 0x7f5e3259e380>)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1479, in dumps
cp.dump(obj)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1245, in dump
return super().dump(obj)
File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 449, in __reduce__
fun, args, arr_state = self._value.__reduce__()
File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
return func(*args, **kwargs)
File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 602, in _value
raise RuntimeError(
RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards.
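As the error message above notes, only the process-local (addressable) shards of a `jax.Array` can be read directly. One possible workaround, sketched here and not necessarily the fix adopted in the repository, is to replace any such array in the worker's return value with host copies of its local shards before Ray pickles it. The helper name `to_ray_safe` and its `to_host` parameter are hypothetical. The only JAX API assumed is the `addressable_shards` attribute, whose entries expose a single-device `.data` array.

```python
from typing import Any, Callable, Optional


def to_ray_safe(value: Any, to_host: Optional[Callable[[Any], Any]] = None) -> Any:
    """Return a copy of `value` in which array-like leaves are replaced by
    host copies of their process-local shards (hypothetical helper)."""
    if to_host is None:
        # Assumption: NumPy is available wherever JAX is installed.
        import numpy as np
        to_host = np.asarray

    if isinstance(value, (list, tuple)):
        converted = [to_ray_safe(item, to_host) for item in value]
        # Preserve the list/tuple distinction of the original container.
        return tuple(converted) if isinstance(value, tuple) else converted

    if hasattr(value, "addressable_shards"):
        # Each shard's .data lives on one local device, so copying it to the
        # host does not touch non-addressable devices.
        return [to_host(shard.data) for shard in value.addressable_shards]

    # Anything else is passed through unchanged.
    return value
```

In a worker method this would be applied just before returning, for example `return to_ray_safe((result, first_token))`. Note that the caller then receives a list of local shards instead of a global array. If the full global value is needed, `jax.experimental.multihost_utils.process_allgather`, which the error message mentions, is the other route.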
`prefill_ray()` now returns a `[result, first_token]` tuple, where `first_token` contains a JAX array. This causes a crash when attempting to fetch the Ray results remotely: