google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Ray engine crashes on multihost when fetching Jax.array from prefill_ray #150

Closed: richardsliu closed this issue 1 month ago

richardsliu commented 2 months ago

`prefill_ray()` now returns a `(result, first_token)` tuple, where `first_token` contains a JAX array. This causes a crash when the Ray results are fetched remotely:

job_id:06000000
:actor_name:ServeReplica:default:JetStreamDeployment
SIGTERM handler is not set because current thread is not the main thread.
Using address example-cluster-kuberay-head-svc.default.svc.cluster.local:6379 set in the environment variable RAY_ADDRESS
Connecting to existing Ray cluster at address: example-cluster-kuberay-head-svc.default.svc.cluster.local:6379...
Calling ray.init() again after it has already been called.
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
Traceback (most recent call last):
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jetstream/core/orchestrator.py", line 162, in run
    super().run()
  File "/home/ray/anaconda3/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jetstream/core/orchestrator.py", line 507, in _prefill_thread
    prefill_result, first_token = prefill_engine.prefill(
  File "/tmp/ray/session_2024-07-12_17-07-57_303234_8/runtime_resources/working_dir_files/_ray_pkg_e66f370ed8382ac2/jetstream_pt/ray_engine.py", line 83, in prefill
    return self.prefill_impl(
  File "/tmp/ray/session_2024-07-12_17-07-57_303234_8/runtime_resources/working_dir_files/_ray_pkg_e66f370ed8382ac2/jetstream_pt/ray_engine.py", line 113, in prefill_impl
    results = ray.get(all_outputs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 2623, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 861, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::PyTorchRayWorker.prefill_ray() (pid=14601, ip=10.104.7.5, actor_id=0721a490262f0d248878f59d06000000, repr=<jetstream_pt.ray_worker.PyTorchRayWorker object at 0x7974fc14e410>)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1479, in dumps
    cp.dump(obj)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1245, in dump
    return super().dump(obj)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 449, in __reduce__
    fun, args, arr_state = self._value.__reduce__()
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 602, in _value
    raise RuntimeError(
RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards.
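One way to follow the error message's own suggestion is sketched below. It assumes `first_token` is (or can be unpacked into) a pytree of `jax.Array`s, and converts it to host-local numpy with `jax.experimental.multihost_utils.process_allgather` inside the worker, before Ray pickles the return value. This is not the fix that landed in jetstream-pytorch; the `to_host` helper and the `first_token` dict are hypothetical. On a single host the array is already fully addressable, so running the sketch there only demonstrates the conversion, not the original failure.

```python
import pickle

import jax.numpy as jnp
import numpy as np
from jax.experimental import multihost_utils


def to_host(tree):
    """Converts every jax.Array leaf of a pytree into a host numpy array.

    On a multihost job a leaf may span devices owned by other processes;
    process_allgather gathers it so each process holds the full value.
    It is a collective: every JAX process must call it with a tree of the
    same structure, otherwise the call hangs.
    """
    return multihost_utils.process_allgather(tree)


# Hypothetical first-token payload; in the real worker this would be the
# jax.Array(s) inside the first_token returned by prefill_ray().
first_token = {"data": jnp.array([[42, 1, 1]], dtype=jnp.int32)}

host_first_token = to_host(first_token)
# numpy arrays pickle directly, without going through jax.Array._value.
payload = pickle.dumps(host_first_token)
restored = pickle.loads(payload)
print(type(restored["data"]).__name__, restored["data"].tolist())
```

Because `process_allgather` is a collective, this pattern only works if every worker's `prefill_ray()` makes the same call, which the `ray.get(all_outputs)` in the traceback suggests the engine does. Gathering the whole prefill result (for example a sharded KV cache) this way could be expensive, which is why the sketch gathers only the first-token payload.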
richardsliu commented 1 month ago

@FanhaiLu1 I'm still seeing the same error after the latest fix. Can you take a look?

    return func(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 2656, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 871, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::PyTorchRayWorker.prefill_ray() (pid=904, ip=10.36.8.6, actor_id=7744b1669cd8fdc6809a72a502000000, repr=<jetstream_pt.ray_worker.PyTorchRayWorker object at 0x7f5e3259e380>)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1479, in dumps
    cp.dump(obj)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1245, in dump
    return super().dump(obj)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 449, in __reduce__
    fun, args, arr_state = self._value.__reduce__()
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 602, in _value
    raise RuntimeError(
RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards.
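Since the same `RuntimeError` persists after the fix, a quick diagnostic could help locate the leaf that is still a global array. The sketch below is not part of jetstream-pytorch (`non_addressable_leaves` is a hypothetical name). It walks a return value with `jax.tree_util.tree_flatten_with_path` and reports every `jax.Array` whose `is_fully_addressable` is `False`, i.e. one that would fail to pickle as in the traceback above. Objects that are not registered as pytrees (for example a plain dataclass) are treated as opaque leaves, so their fields must be unpacked or registered for the walk to see them.

```python
from typing import Any, List

import jax
import jax.numpy as jnp


def non_addressable_leaves(tree: Any) -> List[str]:
    """Returns the paths of jax.Array leaves that span other processes' devices.

    Such arrays cannot be converted to host values on this process, so
    pickling them (as Ray does for actor return values) raises RuntimeError.
    Non-array leaves and unregistered custom objects are skipped.
    """
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    paths = []
    for path, leaf in leaves_with_paths:
        if isinstance(leaf, jax.Array) and not leaf.is_fully_addressable:
            paths.append(jax.tree_util.keystr(path))
    return paths


# Usage on a single host: every array is process-local, so nothing is reported.
# Calling this on the worker right before prefill_ray() returns would show
# which part of (result, first_token) still needs converting.
example = {"result": {"cache": jnp.zeros((2, 4))}, "first_token": [jnp.array([7])]}
print(non_addressable_leaves(example))
```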