ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] Can't load trained model with PettingZoo leduc_holdem_v4 #32910

Status: Open. elliottower opened this issue 1 year ago.

elliottower commented 1 year ago

What happened + What you expected to happen

I'm having trouble loading a trained model for the PettingZoo env leduc_holdem_v4 (I'm working on updating the PettingZoo RLlib tutorials). Training with the train script below (see Reproduction script) succeeds. I tried registering the env as leduc_holdem and as leduc_holdem_v4, using the same name in both files each time, and neither worked. A very similar example with pistonball_v6, where I registered the env as pistonball_v6, loads the model without any problems (I've pasted that code below as well).

Here's the error I keep getting. I don't think it's a Gymnasium or PettingZoo error, since I can load the environments fine: a separate file that only registers the PettingZoo env runs without errors. So I think the problem has to do with how the model is loaded by Algorithm.from_checkpoint(). There used to be a way to specify the env when loading a checkpoint, but that no longer seems possible; I assume the env is now picked up automatically from the registered name, which worked for pistonball but not here for some reason. The only other difference between the two examples is that pistonball registers the env as a ParallelPettingZooEnv for training and as a regular PettingZooEnv in the rendering file, whereas leduc_holdem registers it as a PettingZooEnv in both.

As a side note, the pistonball script doesn't run locally with Ray unless you edit RLlib's ParallelPettingZooEnv file as described in this issue: https://github.com/Farama-Foundation/PettingZoo/issues/889 (removing the now-deprecated return_info argument).

ray.exceptions.RayActorError: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=16491, ip=127.0.0.1, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x7f838134c100>)
  File "/Users/elliottower/anaconda3/envs/PettingZoo/lib/python3.8/site-packages/gymnasium/envs/registration.py", line 569, in make
    _check_version_exists(ns, name, version)
  File "/Users/elliottower/anaconda3/envs/PettingZoo/lib/python3.8/site-packages/gymnasium/envs/registration.py", line 219, in _check_version_exists
    _check_name_exists(ns, name)
  File "/Users/elliottower/anaconda3/envs/PettingZoo/lib/python3.8/site-packages/gymnasium/envs/registration.py", line 197, in _check_name_exists
    raise error.NameNotFound(
gymnasium.error.NameNotFound: Environment leduc_holdem_v4 doesn't exist. 

During handling of the above exception, another exception occurred:

ray::RolloutWorker.__init__() (pid=16491, ip=127.0.0.1, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x7f838134c100>)
  File "/Users/elliottower/anaconda3/envs/PettingZoo/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 607, in __init__
    self.env = env_creator(copy.deepcopy(self.env_context))
  File "/Users/elliottower/anaconda3/envs/PettingZoo/lib/python3.8/site-packages/ray/rllib/env/utils.py", line 178, in _gym_env_creator
    raise EnvError(ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_descriptor))
ray.rllib.utils.error.EnvError: The env string you provided ('leduc_holdem_v4') is:
a) Not a supported/installed environment.
b) Not a tune-registered environment creator.
c) Not a valid env class string.

Try one of the following:
a) For Atari support: `pip install gym[atari] autorom[accept-rom-license]`.
   For VizDoom support: Install VizDoom
   (https://github.com/mwydmuch/ViZDoom/blob/master/doc/Building.md) and
   `pip install vizdoomgym`.
   For PyBullet support: `pip install pybullet`.
b) To register your custom env, do `from ray import tune;
   tune.register('[name]', lambda cfg: [return env obj from here using cfg])`.
   Then in your config, do `config['env'] = [name]`.
c) Make sure you provide a fully qualified classpath, e.g.:
   `ray.rllib.examples.env.repeat_after_me_env.RepeatAfterMeEnv`

During handling of the above exception, another exception occurred:

Versions / Dependencies

Gymnasium==0.26.3 numpy==1.23.5 Pillow==9.3.0 torch==1.13.1 SuperSuit==3.6.0 ray[rllib]==2.3.0 rlcard==1.1.0 PettingZoo==1.22.3 Pillow==9.4.0 tensorflow-probability==0.19.0

Reproduction script

rllib_leduc_holdem.py: (train script)

"""Uses Ray's RLLib to train agents to play Leduc Holdem.

Author: Rohan (https://github.com/Rohan138)
"""

import os

import gymnasium.wrappers
from gymnasium.spaces import Box, Discrete
import ray
from ray import tune
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.env import PettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MAX
from ray.tune.registry import register_env

from pettingzoo.classic import leduc_holdem_v4

torch, nn = try_import_torch()

class TorchMaskedActions(DQNTorchModel):
    """PyTorch DQN model that masks out illegal actions using the env's action_mask."""

    def __init__(self, obs_space: Box, action_space: Discrete, num_outputs, model_config, name, **kw):
        DQNTorchModel.__init__(
            self, obs_space, action_space, num_outputs, model_config, name, **kw
        )

        obs_len = obs_space.shape[0] - action_space.n

        orig_obs_space = Box(
            shape=(obs_len,), low=obs_space.low[:obs_len], high=obs_space.high[:obs_len]
        )
        self.action_embed_model = TorchFC(
            orig_obs_space,
            action_space,
            action_space.n,
            model_config,
            name + "_action_embed",
        )

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the predicted action embedding
        action_logits, _ = self.action_embed_model(
            {"obs": input_dict["obs"]["observation"]}
        )
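        # Turn the 0/1 action mask into an additive mask: log(1) = 0 leaves legal
        # actions unchanged, log(0) = -inf (clamped to -1e10) rules out illegal ones.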
        inf_mask = torch.clamp(torch.log(action_mask), -1e10, FLOAT_MAX)

        return action_logits + inf_mask, state

    def value_function(self):
        return self.action_embed_model.value_function()

if __name__ == "__main__":
    ray.init(local_mode=True)

    alg_name = "DQN"
    ModelCatalog.register_custom_model("pa_model", TorchMaskedActions)
    # function that outputs the environment you wish to register.

    def env_creator():
        env = leduc_holdem_v4.env()
        # env = gymnasium.wrappers.TimeLimit(env, max_episode_steps=200)
        # Throws error: AttributeError: 'OrderEnforcingWrapper' object has no attribute 'spec'
        return env

    env_name = "leduc_holdem_v4"
    register_env(env_name, lambda config: PettingZooEnv(env_creator()))

    test_env = PettingZooEnv(env_creator())
    obs_space = test_env.observation_space
    print(obs_space)
    act_space = test_env.action_space

    config = (
        DQNConfig()
        .environment(env=env_name)
        .rollouts(num_rollout_workers=1, rollout_fragment_length=30)
        .training(
            train_batch_size=200,
            hiddens=[],
            dueling=False,
            model={"custom_model": "pa_model"},
        )
        .multi_agent(
            policies={
                "player_0": (None, obs_space, act_space, {}),
                "player_1": (None, obs_space, act_space, {}),
            },
            policy_mapping_fn=(lambda agent_id, *args, **kwargs: agent_id),
        )
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
        .debugging(log_level="DEBUG") # TODO: change to ERROR to match pistonball example
        .framework(framework="torch")
        .exploration(
            exploration_config={
                # The Exploration class to use.
                "type": "EpsilonGreedy",
                # Config for the Exploration class' constructor:
                "initial_epsilon": 0.1,
                "final_epsilon": 0.0,
                "epsilon_timesteps": 100000,  # Timesteps over which to anneal epsilon.
            }
        )
    )

    tune.run(
        alg_name,
        name="DQN",
        stop={"timesteps_total": 10000000},
        checkpoint_freq=10,
        config=config.to_dict(),
    )
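The masking step in `TorchMaskedActions.forward` can be checked in isolation. The NumPy sketch below mirrors the `torch.clamp(torch.log(action_mask), -1e10, FLOAT_MAX)` line using `np.log` and `np.clip`; `apply_action_mask` is a hypothetical helper name, and the `FLOAT_MAX` value is an assumption standing in for RLlib's constant.

```python
import numpy as np

# Stand-in for RLlib's FLOAT_MAX constant (assumed here: the float32 maximum).
FLOAT_MAX = float(np.finfo(np.float32).max)


def apply_action_mask(logits: np.ndarray, action_mask: np.ndarray) -> np.ndarray:
    """Add log(mask), clamped to [-1e10, FLOAT_MAX], to the logits.

    Legal actions (mask 1) get log(1) = 0 added and are unchanged; illegal
    actions (mask 0) get log(0) = -inf, clamped to -1e10, so they are
    effectively never chosen while the result stays finite.
    """
    with np.errstate(divide="ignore"):  # silence the log(0) warning
        inf_mask = np.clip(np.log(action_mask.astype(np.float64)), -1e10, FLOAT_MAX)
    return logits + inf_mask


logits = np.array([[2.0, 5.0, 1.0, 0.5]])  # raw Q-values/logits for 4 actions
action_mask = np.array([[1, 0, 1, 1]])     # action 1 is illegal
masked = apply_action_mask(logits, action_mask)
print(int(masked.argmax(axis=1)[0]))       # 0 (action 1 would win unmasked)
```

Clamping to -1e10 rather than keeping -inf avoids propagating infinities (and NaNs from -inf minus -inf) through later arithmetic such as loss computation.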

render_rllib_leduc_holdem.py: (load script)

"""Uses Ray's RLlib to view trained agents playing Leduc Holdem.

Author: Rohan (https://github.com/Rohan138)
"""

import argparse
import os

import numpy as np
import ray
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env
from rllib_leduc_holdem import TorchMaskedActions

from pettingzoo.classic import leduc_holdem_v4

os.environ["SDL_VIDEODRIVER"] = "dummy"

parser = argparse.ArgumentParser(
    description="Render pretrained policy loaded from checkpoint"
)
parser.add_argument(
    "--checkpoint-path",
    help="Path to the checkpoint directory, e.g. one of the checkpoint_* folders that the training script above writes under `~/ray_results/DQN/`",
)

args = parser.parse_args()

checkpoint_path = os.path.expanduser(args.checkpoint_path)

alg_name = "DQN"
ModelCatalog.register_custom_model("pa_model", TorchMaskedActions)
# function that outputs the environment you wish to register.

def env_creator():
    env = leduc_holdem_v4.env()
    return env

env = env_creator()
# Must match the name used by the training script: the checkpoint's config
# stores this string, and RLlib looks it up again when rebuilding the workers.
env_name = "leduc_holdem_v4"
register_env(env_name, lambda config: PettingZooEnv(env_creator()))

ray.init()
DQNAgent = Algorithm.from_checkpoint(checkpoint_path)

reward_sums = {a: 0 for a in env.possible_agents}
i = 0
env.reset()

for agent in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()
    reward_sums[agent] += reward
    if termination or truncation:
        # AEC envs expect a None action for agents that are already done.
        action = None
    else:
        # Policies were trained with IDs equal to the agent names (player_0, player_1).
        policy = DQNAgent.get_policy(agent)
        # Add a batch dimension of 1 to each part of the dict observation.
        batch_obs = {
            "obs": {
                "observation": np.expand_dims(observation["observation"], 0),
                "action_mask": np.expand_dims(observation["action_mask"], 0),
            }
        }
        batched_action, state_out, info = policy.compute_actions_from_input_dict(
            batch_obs
        )
        action = batched_action[0]

    env.step(action)
    i += 1
    env.render()

print("rewards:")
print(reward_sums)
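The policy computes actions on batches, which is why the loop adds a leading dimension to each part of the observation and then takes `batched_action[0]`. A small helper that does the same for any dict observation is sketched below; `batch_single_obs` is a hypothetical name, and the observation sizes are placeholders rather than values read from leduc_holdem_v4.

```python
from typing import Dict

import numpy as np


def batch_single_obs(observation: Dict[str, np.ndarray]) -> Dict[str, Dict[str, np.ndarray]]:
    """Wrap one agent's dict observation as a batch of size 1.

    Same structure as the `batch_obs` dict in the loop above, but built for
    whatever keys the observation has.
    """
    return {"obs": {key: np.expand_dims(value, 0) for key, value in observation.items()}}


# Placeholder observation; the 36/4 sizes are illustrative only.
observation = {
    "observation": np.zeros(36, dtype=np.float32),
    "action_mask": np.array([1, 1, 0, 1], dtype=np.int8),
}
batch = batch_single_obs(observation)
print({k: v.shape for k, v in batch["obs"].items()})
# {'observation': (1, 36), 'action_mask': (1, 4)}
```

`np.expand_dims` keeps each array's dtype, so the 0/1 action mask is passed through unchanged apart from the new batch axis.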

For context on the working example with pistonball, here is rllib_pistonball.py:

"""Uses Ray's RLLib to train agents to play Pistonball.

Author: Rohan (https://github.com/Rohan138)
"""

import os

import ray
import supersuit as ss
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.tune.registry import register_env
from torch import nn

from pettingzoo.butterfly import pistonball_v6

class CNNModelV2(TorchModelV2, nn.Module):
    def __init__(self, obs_space, act_space, num_outputs, *args, **kwargs):
        TorchModelV2.__init__(self, obs_space, act_space, num_outputs, *args, **kwargs)
        nn.Module.__init__(self)
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, [8, 8], stride=(4, 4)),
            nn.ReLU(),
            nn.Conv2d(32, 64, [4, 4], stride=(2, 2)),
            nn.ReLU(),
            nn.Conv2d(64, 64, [3, 3], stride=(1, 1)),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(),
        )
        self.policy_fn = nn.Linear(512, num_outputs)
        self.value_fn = nn.Linear(512, 1)

    def forward(self, input_dict, state, seq_lens):
        model_out = self.model(input_dict["obs"].permute(0, 3, 1, 2))
        self._value_out = self.value_fn(model_out)
        return self.policy_fn(model_out), state

    def value_function(self):
        return self._value_out.flatten()

def env_creator(args):
    env = pistonball_v6.parallel_env(
        n_pistons=20,
        time_penalty=-0.1,
        continuous=True,
        random_drop=True,
        random_rotate=True,
        ball_mass=0.75,
        ball_friction=0.3,
        ball_elasticity=1.5,
        max_cycles=125,
    )
    env = ss.color_reduction_v0(env, mode="B")
    env = ss.dtype_v0(env, "float32")
    env = ss.resize_v1(env, x_size=84, y_size=84)
    env = ss.normalize_obs_v0(env, env_min=0, env_max=1)
    env = ss.frame_stack_v1(env, 3)
    return env

if __name__ == "__main__":
    ray.init(local_mode=True)

    env_name = "pistonball_v6"

    register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
    ModelCatalog.register_custom_model("CNNModelV2", CNNModelV2)

    config = (
        PPOConfig()
        .environment(env=env_name, clip_actions=True)
        .rollouts(num_rollout_workers=4, rollout_fragment_length=128)
        .training(
            train_batch_size=512,
            lr=2e-5,
            gamma=0.99,
            lambda_=0.9,
            use_gae=True,
            clip_param=0.4,
            grad_clip=None,
            entropy_coeff=0.1,
            vf_loss_coeff=0.25,
            sgd_minibatch_size=64,
            num_sgd_iter=10,
        )
        .debugging(log_level="ERROR")
        .framework(framework="torch")
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
    )

    tune.run(
        "PPO",
        name="PPO",
        stop={"timesteps_total": 5000000},
        checkpoint_freq=10,
        local_dir="~/ray_results/" + env_name,
        config=config.to_dict(),
    )
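The `3136` in `nn.Linear(3136, 512)` follows from the 84x84 frames produced by `ss.resize_v1` and PyTorch's Conv2d output-size formula, floor((H + 2 * padding - kernel) / stride) + 1 for dilation 1 (no padding is used here). The 3 input channels of the first conv come from `frame_stack_v1(env, 3)` stacking three single-channel frames after `color_reduction_v0`. The helper names below are hypothetical; the arithmetic is the point.

```python
from typing import List, Tuple


def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Spatial output size of one Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_size(input_hw: int, layers: List[Tuple[int, int]], out_channels: int) -> int:
    """Size of the nn.Flatten() output after square conv layers [(kernel, stride), ...]."""
    size = input_hw
    for kernel, stride in layers:
        size = conv_out(size, kernel, stride)
    return out_channels * size * size


# CNNModelV2: 84x84 input, kernels/strides (8,4), (4,2), (3,1), 64 output channels.
print(conv_out(84, 8, 4), conv_out(20, 4, 2), conv_out(9, 3, 1))  # 20 9 7
print(flattened_size(84, [(8, 4), (4, 2), (3, 1)], 64))           # 3136
```

If the resize target or any kernel/stride changes, the first `nn.Linear` input size has to be recomputed the same way, or the forward pass fails with a shape mismatch.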

And render_rllib_pistonball.py (load script):

"""Uses Ray's RLLib to view trained agents playing Pistonball.

Author: Rohan (https://github.com/Rohan138)
"""

import argparse
import os

import ray
import supersuit as ss
from PIL import Image
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env
from tutorials.Ray.rllib_pistonball import CNNModelV2

from pettingzoo.butterfly import pistonball_v6

os.environ["SDL_VIDEODRIVER"] = "dummy"

parser = argparse.ArgumentParser(
    description="Render pretrained policy loaded from checkpoint"
)
parser.add_argument(
    "--checkpoint-path",
    help="Path to the checkpoint. This path will likely be something like this: `~/ray_results/pistonball_v6/PPO/PPO_pistonball_v6_660ce_00000_0_2021-06-11_12-30-57/checkpoint_000050/checkpoint-50`",
)

args = parser.parse_args()

checkpoint_path = os.path.expanduser(args.checkpoint_path)

ModelCatalog.register_custom_model("CNNModelV2", CNNModelV2)

def env_creator():
    env = pistonball_v6.env(
        n_pistons=20,
        time_penalty=-0.1,
        continuous=True,
        random_drop=True,
        random_rotate=True,
        ball_mass=0.75,
        ball_friction=0.3,
        ball_elasticity=1.5,
        max_cycles=125,
        render_mode="rgb_array",
    )
    env = ss.color_reduction_v0(env, mode="B")
    env = ss.dtype_v0(env, "float32")
    env = ss.resize_v1(env, x_size=84, y_size=84)
    env = ss.normalize_obs_v0(env, env_min=0, env_max=1)
    env = ss.frame_stack_v1(env, 3)
    return env

env = env_creator()
env_name = "pistonball_v6"
register_env(env_name, lambda config: PettingZooEnv(env_creator()))

ray.init()

PPOagent = PPO.from_checkpoint(checkpoint_path)

reward_sum = 0
frame_list = []
i = 0
env.reset()

for agent in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()
    reward_sum += reward
    if termination or truncation:
        action = None
    else:
        action = PPOagent.compute_single_action(observation)

    env.step(action)
    i += 1
    if i % (len(env.possible_agents) + 1) == 0:
        img = Image.fromarray(env.render())
        frame_list.append(img)
env.close()

print(reward_sum)
frame_list[0].save(
    "out.gif", save_all=True, append_images=frame_list[1:], duration=3, loop=0
)

Issue Severity

High: It blocks me from completing my task.

ChaceAshcraft commented 1 year ago

I'm having a similar issue with VizDoom. If I run the RLlib VizDoom example with PPO, it is fine, but if I want to run DQN it tells me that the VizDoom environment doesn't exist.

elliottower commented 1 year ago

Interesting to hear that it can depend on the algorithm; maybe I'll try a few different ones and see if that's the issue. DQN is where I'm getting errors as well. It's very strange that the choice of algorithm would have any bearing on environment loading, though.