What happened + What you expected to happen
I am having issues loading a DreamerV3 checkpoint for inference. Similar to what was discussed in #40312, I assume it has to do with the old/new API stack.
When trying to load the DreamerV3 checkpoint via Algorithm.from_checkpoint, the following error appears:
Traceback (most recent call last):
File "/home/fabian/Desktop/ray_rllib/dreamerv3_inference.py", line 38, in <module>
algo = Algorithm.from_checkpoint(args.checkpoint)
File "/home/fabian/miniconda3/envs/fpvPy310ray/lib/python3.10/site-packages/ray/rllib/algorithms/algorithm.py", line 337, in from_checkpoint
state = Algorithm._checkpoint_info_to_algorithm_state(
File "/home/fabian/miniconda3/envs/fpvPy310ray/lib/python3.10/site-packages/ray/rllib/algorithms/algorithm.py", line 2905, in _checkpoint_info_to_algorithm_state
policy_ids if policy_ids is not None else worker_state["policy_ids"]
KeyError: 'policy_ids'
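The failing lookup happens while RLlib converts the checkpoint into an old-API-stack algorithm state: the saved worker state apparently has no policy_ids entry, which would fit a new-API-stack (RLModule-based) checkpoint like DreamerV3's. To double-check what the checkpoint actually contains, it can be inspected with the same helper that from_checkpoint uses internally (minimal sketch; the path is a placeholder):

    from ray.rllib.utils.checkpoints import get_checkpoint_info

    # Classify the checkpoint the same way Algorithm.from_checkpoint() does.
    info = get_checkpoint_info("PATH/TO/YOUR/DREAMERV3/CHECKPOINT")

    # Shows e.g. the checkpoint type ("Algorithm"), format, and version --
    # useful to confirm which API stack wrote the checkpoint.
    for key, value in info.items():
        print(f"{key}: {value}")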
Versions / Dependencies
ray 2.9.0, Python 3.10
Reproduction script
Essentially just the example script from the DreamerV3 docs, but loading the algo from a checkpoint instead of creating it from scratch:
import gymnasium as gym
import numpy as np
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.core.models.base import STATE_IN, STATE_OUT
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf
tf1, tf, tfv = try_import_tf()
env_name = "CartPole-v1"
# Use the vector env API.
env = gym.vector.make(env_name, num_envs=1, asynchronous=False)
terminated = truncated = False
# Reset the env.
obs, _ = env.reset()
# Every time we start a new episode, we should set is_first to True for the
# upcoming action inference.
is_first = 1.0
algo = Algorithm.from_checkpoint("PATH/TO/YOUR/DREAMERV3/CHECKPOINT")
# Extract the actual RLModule from the local (Dreamer) EnvRunner.
rl_module = algo.workers.local_worker().module
# Get initial states from RLModule (note that these are always B=1, so this matches
# our num_envs=1; if you are using a vector env >1, you would have to repeat the
# returned states `num_envs` times to get the correct batch size):
states = rl_module.get_initial_state()
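# Hypothetical sketch (not part of the original example): for a vector env
# with num_envs > 1, the B=1 states would presumably have to be tiled first:
# states = tf.nest.map_structure(
#     lambda s: tf.repeat(s, env.num_envs, axis=0), states
# )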
while not terminated and not truncated:
    # Use the RLModule for action computations directly.
    # DreamerV3 expects this particular batch format: obs, prev. states and the
    # `is_first` flag.
    batch = {
        # states is already batched (B=1).
        STATE_IN: states,
        # obs is already batched (due to vector env).
        SampleBatch.OBS: tf.convert_to_tensor(obs),
        # Set to True at the beginning of the episode.
        "is_first": tf.convert_to_tensor([is_first]),
    }
    outs = rl_module.forward_inference(batch)
    # Extract actions (which are in one-hot format) and state-outs from outs.
    actions = np.argmax(outs[SampleBatch.ACTIONS].numpy(), axis=-1)
    states = outs[STATE_OUT]
    # Perform a step in the env.
    obs, reward, terminated, truncated, info = env.step(actions)
    # Not at the beginning of the episode anymore.
    is_first = 0.0
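A possible workaround (untested sketch; it assumes the checkpoint contains the learner/ sub-directory that Algorithm.save_checkpoint writes on the new API stack, and that the original training config can be re-created exactly, including any non-default model settings) would be to build the algo from scratch and restore only the learner/RLModule state instead of going through Algorithm.from_checkpoint:

    import os

    from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config

    checkpoint_path = "PATH/TO/YOUR/DREAMERV3/CHECKPOINT"

    # Re-create the training config (must match the one that produced the
    # checkpoint, e.g. same env and model size).
    config = DreamerV3Config().environment("CartPole-v1")
    algo = config.build()

    # Restore only the learner state (which holds the RLModule weights) ...
    algo.learner_group.load_state(os.path.join(checkpoint_path, "learner"))
    # ... and push those weights to the EnvRunner's copy of the RLModule.
    algo.workers.sync_weights(from_worker_or_learner_group=algo.learner_group)

    rl_module = algo.workers.local_worker().module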
Issue Severity
Low: It annoys or frustrates me.