Hi, thanks for the efficient training. I wonder if there is an easy way to run evaluation after training. I know you have a `test` function that runs after training, but it only does one round and is also quite hard for me to modify LOL. Thanks~
Hi @hanshuo-shuo, yeah you're right: the `test` function is run only once, after training. Can you please try this instead:
```python
import pathlib
from functools import partial

import gymnasium as gym
import torch
from lightning import Fabric
from omegaconf import OmegaConf

from sheeprl.algos.dreamer_v3.agent import PlayerDV3, build_models
from sheeprl.algos.dreamer_v3.utils import test
from sheeprl.envs.wrappers import RestartOnException
from sheeprl.utils.env import make_env

if __name__ == "__main__":
    fabric = Fabric(accelerator="gpu", devices="1")
    # `expanduser()` resolves the leading `~` before the checkpoint is loaded
    ckpt_path = pathlib.Path(
        "~/sheeprl/logs/runs/dreamer_v3/{date}/{env_id}/version_0/checkpoint/ckpt_1000000_0.ckpt"
    ).expanduser()
    cfg = OmegaConf.load("/home/federico.belotti/sheeprl/logs/runs/dreamer_v3/{date}/{env_id}/.hydra/config.yaml")
    ckpt = fabric.load(ckpt_path)

    # Change configs as needed
    cfg.seed = 0
    cfg.env.num_envs = 1
    cfg.env.capture_video = True

    # Recreate a single environment
    rank = 0
    vectorized_env = gym.vector.SyncVectorEnv
    envs = vectorized_env(
        [
            partial(
                RestartOnException,
                env_fn=make_env(
                    cfg,
                    cfg.seed + rank * cfg.env.num_envs + i,
                    rank * cfg.env.num_envs,
                    None,
                    "",
                    vector_env_idx=i,
                ),
            )
            for i in range(1)
        ],
    )
    action_space = envs.single_action_space
    observation_space = envs.single_observation_space
    is_continuous = isinstance(action_space, gym.spaces.Box)
    is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
    actions_dim = (
        action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
    )
    clip_rewards_fn = lambda r: torch.tanh(r) if cfg.clip_rewards else r  # not used in this standalone script
    cnn_keys = cfg.cnn_keys.encoder
    mlp_keys = cfg.mlp_keys.encoder
    fabric.print("CNN keys:", cnn_keys)
    fabric.print("MLP keys:", mlp_keys)
    obs_keys = cnn_keys + mlp_keys

    # Close the environment since it will be recreated inside the `test` function
    envs.close()

    # Create models and load weights from checkpoint
    world_model, actor, critic, target_critic = build_models(
        fabric,
        actions_dim,
        is_continuous,
        cfg,
        observation_space,
        ckpt["world_model"],
        ckpt["actor"],
        ckpt["critic"],
        ckpt["target_critic"],
    )

    # Create the player agent
    player = PlayerDV3(
        world_model.encoder.module,
        world_model.rssm,
        actor.module,
        actions_dim,
        cfg.algo.player.expl_amount,
        cfg.env.num_envs,
        cfg.algo.world_model.stochastic_size,
        cfg.algo.world_model.recurrent_model.recurrent_state_size,
        fabric.device,
        discrete_size=cfg.algo.world_model.discrete_size,
    )

    # Test the agent
    sample_actions = True  # Whether to sample actions from the actor's distribution or not
    test(
        player,
        fabric,
        cfg,
        f"{cfg.env.id}_{cfg.seed}_sample_{sample_actions}",
        sample_actions=sample_actions,
    )
```
You just need to change the checkpoint and config paths and it should run.
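Since `test` runs a single evaluation episode, if you need multiple rounds a quick (untested) workaround is to call `test` in a loop with different seeds at the end of the script above, inside the `__main__` block:

```python
    # Untested sketch: run multiple evaluation episodes by varying the seed.
    # Reuses `player`, `fabric`, `cfg` and `sample_actions` from the script above.
    for eval_seed in range(5):
        cfg.seed = eval_seed
        test(
            player,
            fabric,
            cfg,
            f"{cfg.env.id}_{cfg.seed}_sample_{sample_actions}",
            sample_actions=sample_actions,
        )
```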
@belerico, we could add a template (script or notebook) for evaluating the trained models. What do you think?
Yes, that worked, thanks a lot~
> @belerico, we could add a template (script or notebook) for evaluating the trained models. What do you think?
I agree. It should be algorithm-dependent, given how the library is organized. Maybe we can have something like this:

- create a decorator like `register_eval` to register an evaluation function
- the evaluation function should be placed under `<algo>/evaluate.py` with a standard naming, like `def evaluate(...)`, where the args have to be defined
- we should have a `sheeprl_evaluate.py` where we load the hydra confs, load the checkpoint and call the specific `evaluate` function

What do you think? I can have something prepared so that we can discuss.
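For concreteness, a minimal sketch of what I have in mind (all names below, like `evaluation_registry`, the decorator, and the `evaluate` signature, are placeholders for discussion, not a final API):

```python
# Hypothetical sketch of the proposal: every name below is a placeholder.
import importlib

from omegaconf import OmegaConf

# Maps an algorithm name to its registered evaluation function
evaluation_registry = {}


def register_eval(fn):
    """Register `fn` under its algorithm's name, e.g. a function defined in
    `sheeprl.algos.dreamer_v3.evaluate` is registered under "dreamer_v3"."""
    evaluation_registry[fn.__module__.split(".")[-2]] = fn
    return fn


# Each algorithm would then define, in sheeprl/algos/<algo>/evaluate.py:
#
#   @register_eval
#   def evaluate(cfg, ckpt_path):
#       ...  # rebuild the agent from the checkpoint and run the test loop


def main(ckpt_path: str, cfg_path: str):
    """`sheeprl_evaluate.py` entrypoint: load the hydra config, then dispatch
    to the algorithm-specific `evaluate` function with the checkpoint path."""
    cfg = OmegaConf.load(cfg_path)
    # Importing the module triggers the decorator, which fills the registry
    importlib.import_module(f"sheeprl.algos.{cfg.algo.name}.evaluate")
    evaluation_registry[cfg.algo.name](cfg, ckpt_path)
```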
An algorithm-dependent evaluation script is a good idea because it will make it easy for users to evaluate their agents.
@hanshuo-shuo Can you try out the code of the linked PR?