ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

Cannot create RLPredictor using restored checkpoint in different Ray session #33995

Open stefan4444 opened 1 year ago

stefan4444 commented 1 year ago

Ray raises the following error when an RLPredictor is created from a restored checkpoint in a different Ray session from the one that created the checkpoint. The same code works when the predictor is created in the session that produced the checkpoint. See the traceback and repro script below; a minimal sketch of the suspected underlying failure mode follows the traceback.

Traceback (most recent call last):
  File "/home/ec2-user/environment/restore_checkpoint.py", line 130, in <module>
    predictor = RLPredictor.from_checkpoint(checkpoint)
  File "/home/ec2-user/.local/lib/python3.7/site-packages/ray/train/rl/rl_predictor.py", line 63, in from_checkpoint
    policy = checkpoint.get_policy(env)
  File "/home/ec2-user/.local/lib/python3.7/site-packages/ray/train/rl/rl_checkpoint.py", line 42, in get_policy
    return Policy.from_checkpoint(checkpoint=self)["default_policy"]
  File "/home/ec2-user/.local/lib/python3.7/site-packages/ray/rllib/policy/policy.py", line 261, in from_checkpoint
    policy_state = pickle.load(f)
  File "/home/ec2-user/.local/lib/python3.7/site-packages/ray/_private/serialization.py", line 89, in _actor_handle_deserializer
    return ray.actor.ActorHandle._deserialization_helper(serialized_obj, outer_id)
  File "/home/ec2-user/.local/lib/python3.7/site-packages/ray/actor.py", line 1281, in _deserialization_helper
    state, outer_object_ref
  File "python/ray/_raylet.pyx", line 2319, in ray._raylet.CoreWorker.deserialize_and_register_actor_handle
  File "python/ray/_raylet.pyx", line 2288, in ray._raylet.CoreWorker.make_actor_handle
  File "/home/ec2-user/.local/lib/python3.7/site-packages/ray/_private/function_manager.py", line 523, in load_actor_class
    job_id, actor_creation_function_descriptor
  File "/home/ec2-user/.local/lib/python3.7/site-packages/ray/_private/function_manager.py", line 617, in _load_actor_class_from_gcs
    class_name = ensure_str(class_name)
  File "/home/ec2-user/.local/lib/python3.7/site-packages/ray/_private/utils.py", line 293, in ensure_str
    assert isinstance(s, bytes)
AssertionError
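
The traceback suggests that the pickled policy state inside the checkpoint contains a Ray ActorHandle, and actor handles are only valid within the session that created them. A minimal sketch of that suspected failure mode, independent of RLlib (the Dummy actor is hypothetical):

import ray
import ray.cloudpickle as cloudpickle

@ray.remote
class Dummy:  # hypothetical actor standing in for whatever handle the checkpoint captured
    def ping(self):
        return "pong"

ray.init()
handle = Dummy.remote()
blob = cloudpickle.dumps(handle)  # serializing the handle works in-session
print(ray.get(cloudpickle.loads(blob).ping.remote()))  # so does deserializing it
ray.shutdown()

ray.init()               # fresh session: the original actor no longer exists
cloudpickle.loads(blob)  # fails during actor-handle deserialization, analogous
                         # to the AssertionError in the traceback above
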
Repro script:

import gymnasium as gym
import numpy as np
import os
import shutil

import ray._private.utils

from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder
from ray.rllib.offline.json_writer import JsonWriter

import ray
from ray.air import Checkpoint
from ray.air import CheckpointConfig
from ray.air.config import RunConfig, ScalingConfig
from ray.train.rl import RLTrainer
from ray.rllib.algorithms.bc.bc import BC
from ray.train.rl.rl_predictor import RLPredictor

### start ray
ray.init()

### offline data directory
offline_data_dir = os.path.join(
    ray._private.utils.get_user_temp_dir(),
    'cartpole_offline_data'
)

### remove offline data directory if it exists
if os.path.exists(offline_data_dir) and os.path.isdir(offline_data_dir):
    shutil.rmtree(offline_data_dir)

### create environment
env = gym.make('CartPole-v1')

### create pre-processor
prep = get_preprocessor(env.observation_space)(env.observation_space)

### save 100 sample trajectories to offline data
batch_builder = SampleBatchBuilder()
writer = JsonWriter(offline_data_dir)
for eps_id in range(100):
    obs, info = env.reset()
    prev_action = np.zeros_like(env.action_space.sample())
    prev_reward = 0
    terminated = truncated = False
    t = 0
    while not terminated and not truncated:
        action = env.action_space.sample()
        new_obs, rew, terminated, truncated, info = env.step(action)
        batch_builder.add_values(
            t=t,
            eps_id=eps_id,
            agent_index=0,
            obs=prep.transform(obs),
            actions=action,
            action_prob=1.0,  # put the true action probability here
            action_logp=0.0,
            rewards=rew,
            prev_actions=prev_action,
            prev_rewards=prev_reward,
            terminateds=terminated,
            truncateds=truncated,
            infos=info,
            new_obs=prep.transform(new_obs),
        )
        obs = new_obs
        prev_action = action
        prev_reward = rew
        t += 1
    writer.write(batch_builder.build_and_reset())

### load offline data
dataset = ray.data.read_json(
    offline_data_dir,
    parallelism=2,
    ray_remote_args=dict(num_cpus=1)
)

### train rl model and save checkpoint at the end
trainer = RLTrainer(
    run_config=RunConfig(
        stop=dict(training_iteration=2),
        checkpoint_config=CheckpointConfig(checkpoint_at_end=True),
    ),
    scaling_config=ScalingConfig(
        num_workers=1,
        use_gpu=False,
    ),
    datasets={'train': dataset},
    algorithm=BC,
    config=dict(
        env='CartPole-v1',
        framework='torch',
        evaluation_num_workers=1,
        evaluation_interval=1,
        evaluation_config=dict(input='sampler')
    )
)
result = trainer.fit()

### get checkpoint path
checkpoint_path = result.checkpoint._local_path

### restore checkpoint
checkpoint = Checkpoint.from_directory(path=checkpoint_path)

### create predictor from checkpoint: THIS WORKS WHEN PERFORMED IN THE SAME RAY SESSION CHECKPOINT WAS CREATED IN
predictor = RLPredictor.from_checkpoint(checkpoint)

### close ray session
ray.shutdown()

### start new ray session
ray.init()

### restore checkpoint
checkpoint = Checkpoint.from_directory(path=checkpoint_path)

### create predictor from checkpoint: THIS DOES NOT WORK WHEN PERFORMED IN A DIFFERENT RAY SESSION FROM THE ONE THAT CHECKPOINT WAS CREATED IN
predictor = RLPredictor.from_checkpoint(checkpoint)
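
For reference, once from_checkpoint succeeds (as it does in the first session above), the predictor can be queried as in this hypothetical usage sketch:

### hypothetical usage sketch: the predictor maps a batch of observations to actions
obs, info = env.reset()
actions = predictor.predict(np.array([obs]))  # batch of one observation in
print(actions)
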
richardliaw commented 1 year ago

cc @sven1977

skourta commented 1 year ago

I am facing the same issue when trying to resume an interrupted training run using Tuner.restore. A sketch of the resume pattern follows the traceback:

    results = tuner.fit()
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/site-packages/ray/tune/tuner.py", line 347, in fit
    return self._local_tuner.fit()
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/site-packages/ray/tune/impl/tuner_internal.py", line 590, in fit
    analysis = self._fit_resume(trainable, param_space)
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/site-packages/ray/tune/impl/tuner_internal.py", line 738, in _fit_resume
    analysis = run(**args)
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/site-packages/ray/tune/tune.py", line 1036, in run
    runner = trial_runner_cls(**runner_kwargs)
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/site-packages/ray/tune/execution/tune_controller.py", line 149, in __init__
    super().__init__(
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/site-packages/ray/tune/execution/trial_runner.py", line 258, in __init__
    self.resume(
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/site-packages/ray/tune/execution/trial_runner.py", line 506, in resume
    trials = self.restore_from_dir()
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/site-packages/ray/tune/execution/trial_runner.py", line 444, in restore_from_dir
    trial = Trial.from_json_state(trial_json_state)
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/site-packages/ray/tune/experiment/trial.py", line 1136, in from_json_state
    trial_state = json.loads(json_state, cls=TuneFunctionDecoder)
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/json/__init__.py", line 359, in loads
    return cls(**kw).decode(s)
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/json/decoder.py", line 337, in decode
    obj, end = self.raw_decode(s, idx=_w(s, 0).end())
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/json/decoder.py", line 353, in raw_decode
    obj, end = self.scan_once(s, idx)
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/site-packages/ray/tune/utils/serialization.py", line 39, in object_hook
    return self._from_cloudpickle(obj)
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/site-packages/ray/tune/utils/serialization.py", line 43, in _from_cloudpickle
    return cloudpickle.loads(hex_to_binary(obj["value"]))
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/site-packages/ray/_private/serialization.py", line 105, in _actor_handle_deserializer
    return ray.actor.ActorHandle._deserialization_helper(serialized_obj, outer_id)
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/site-packages/ray/actor.py", line 1292, in _deserialization_helper
    return worker.core_worker.deserialize_and_register_actor_handle(
  File "python/ray/_raylet.pyx", line 3503, in ray._raylet.CoreWorker.deserialize_and_register_actor_handle
  File "python/ray/_raylet.pyx", line 3472, in ray._raylet.CoreWorker.make_actor_handle
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/site-packages/ray/_private/function_manager.py", line 574, in load_actor_class
    actor_class = self._load_actor_class_from_gcs(
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/site-packages/ray/_private/function_manager.py", line 669, in _load_actor_class_from_gcs
    class_name = ensure_str(class_name)
  File "/scratch/sk10691/conda-envs/main/lib/python3.10/site-packages/ray/_private/utils.py", line 239, in ensure_str
    assert isinstance(s, bytes)
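
For context, the resume pattern that triggers this looks like the sketch below; the experiment path and the trivial trainable are hypothetical placeholders:

from ray import tune

def trainable(config):  # placeholder trainable
    return {"score": 1}

tuner = tune.Tuner.restore(
    path="~/ray_results/my_experiment",  # placeholder experiment directory
    trainable=trainable,
)
results = tuner.fit()  # fails while unpickling trial state that contains a
                       # session-bound ActorHandle
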
liuzhy71 commented 1 year ago

Same issue here. Don't know what is wrong.

zk9907 commented 8 months ago

I'm facing the same issue with Python 3.7.12 and Ray 2.4.0.