
[rllib] KeyError: 'policy_ids' when loading DreamerV3 Checkpoint #42393

Open · defrag-bambino opened this issue 9 months ago

defrag-bambino commented 9 months ago

What happened + What you expected to happen

I am having issues loading a DreamerV3 checkpoint for inference. Similar to what was discussed in #40312, I assume this is related to the old vs. new API stack. When trying to load the DreamerV3 checkpoint, the following error appears:

Traceback (most recent call last):
  File "/home/fabian/Desktop/ray_rllib/dreamerv3_inference.py", line 38, in <module>
    algo = Algorithm.from_checkpoint(args.checkpoint)
  File "/home/fabian/miniconda3/envs/fpvPy310ray/lib/python3.10/site-packages/ray/rllib/algorithms/algorithm.py", line 337, in from_checkpoint
    state = Algorithm._checkpoint_info_to_algorithm_state(
  File "/home/fabian/miniconda3/envs/fpvPy310ray/lib/python3.10/site-packages/ray/rllib/algorithms/algorithm.py", line 2905, in _checkpoint_info_to_algorithm_state
    policy_ids if policy_ids is not None else worker_state["policy_ids"]
KeyError: 'policy_ids'
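
To see what the checkpoint actually stored, one can unpickle the algorithm state directly and list the keys of the worker state that the failing line indexes into. This is only a diagnostic sketch: it assumes the Ray 2.x checkpoint layout with an algorithm_state.pkl file at the checkpoint root and a "worker" entry inside it, and the path is the same placeholder as in the repro script below.

import pickle
from pathlib import Path

# Assumed Ray 2.x layout: <checkpoint_dir>/algorithm_state.pkl.
state_file = Path("PATH/TO/YOUR/DREAMERV3/CHECKPOINT") / "algorithm_state.pkl"
with open(state_file, "rb") as f:
    state = pickle.load(f)

# The failing line reads worker_state["policy_ids"]; listing the keys shows
# which entries the (new-API-stack) checkpoint actually contains.
print(sorted(state["worker"].keys()))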

Versions / Dependencies

ray 2.9.0

Reproduction script

Essentially just the example script from the DreamerV3 doc here, but loading the algorithm from a checkpoint instead of creating it from scratch:

import gymnasium as gym
import numpy as np

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config
from ray.rllib.core.models.base import STATE_IN, STATE_OUT
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

env_name = "CartPole-v1"
# Use the vector env API.
env = gym.vector.make(env_name, num_envs=1, asynchronous=False)

terminated = truncated = False
# Reset the env.
obs, _ = env.reset()
# Every time we start a new episode, we should set is_first to True for the
# upcoming action inference.
is_first = 1.0

algo = Algorithm.from_checkpoint("PATH/TO/YOUR/DREAMERV3/CHECKPOINT")

# Extract the actual RLModule from the local (Dreamer) EnvRunner.
rl_module = algo.workers.local_worker().module
# Get the initial states from the RLModule (note that these are always B=1,
# which matches our num_envs=1; with a vector env of size >1, the returned
# states would have to be repeated `num_envs` times to get the correct batch
# size, as sketched below):
states = rl_module.get_initial_state()
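# For a vector env with num_envs > 1, each state tensor would first have to
# be tiled along the batch axis, e.g. (a sketch using dm_tree; `num_envs`
# stands for the vector env size):
#   import tree
#   states = tree.map_structure(lambda s: tf.repeat(s, num_envs, axis=0), states)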

while not terminated and not truncated:
    # Use the RLModule for action computations directly.
    # DreamerV3 expects this particular batch format: obs, prev. states and the
    # `is_first` flag.
    batch = {
        # states is already batched (B=1)
        STATE_IN: states,
        # obs is already batched (due to vector env).
        SampleBatch.OBS: tf.convert_to_tensor(obs),
        # set to True at beginning of episode.
        "is_first": tf.convert_to_tensor([is_first]),
    }
    outs = rl_module.forward_inference(batch)
    # Extract the actions (returned in one-hot format) and state-outs from outs.
    actions = np.argmax(outs[SampleBatch.ACTIONS].numpy(), axis=-1)
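    # E.g. for CartPole's two discrete actions, one-hot [0., 1.] -> action index 1.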
    states = outs[STATE_OUT]

    # Perform a step in the env.
    obs, reward, terminated, truncated, info = env.step(actions)
    # Not at the beginning of the episode anymore.
    is_first = 0.0
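
As a possible workaround until Algorithm.from_checkpoint handles new-API-stack checkpoints, rebuilding the algorithm from the same config that produced the checkpoint and restoring into it might work. This is only a sketch under that assumption; the minimal config below is a stand-in for whatever configuration was actually used in training:

from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config

# Must match the training configuration that wrote the checkpoint.
config = DreamerV3Config().environment("CartPole-v1")
algo = config.build()
# restore() is the generic Trainable checkpoint-loading path, so it may
# sidestep the policy_ids lookup in _checkpoint_info_to_algorithm_state.
algo.restore("PATH/TO/YOUR/DREAMERV3/CHECKPOINT")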

Issue Severity

Low: It annoys or frustrates me.

KrTG commented 8 months ago

Using the nightly build fixed it for me.

simonsays1980 commented 7 months ago

@defrag-bambino Could you try the newest version of Ray and see if the error persists? That solved it for @KrTG.