jesuspc opened this issue 11 months ago
Great catch, @jesuspc! Thanks for pointing us towards this issue. I have looked into it a bit deeper and found a little twist in the docs, which I will change later.

I can follow your argument. As far as I can see, connectors are considered in `compute_single_action` in the source code. The question here is why they are not applied. I will try to dig deeper in the next hours, so thanks again for reporting this.

For now, could you try the following (as shown in the docs, but without the policy ID) and tell me if this makes things easier for you?
```python
import gymnasium as gym

from ray import air, tune
from ray.rllib.algorithms.appo.appo import APPOConfig
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.policy import local_policy_inference

config = (
    APPOConfig()
    .environment(
        env="CartPole-v1",
    )
    .rollouts(
        num_envs_per_worker=5,
        num_rollout_workers=1,
        observation_filter="MeanStdFilter",
        # Default.
        enable_connectors=True,
    )
    .training(
        num_sgd_iter=6,
        vf_loss_coeff=0.01,
        vtrace=False,
        model={
            "fcnet_hiddens": [32],
            "fcnet_activation": "linear",
            "vf_share_layers": False,
        },
    )
)

tuner = tune.Tuner(
    "APPO",
    run_config=air.RunConfig(
        name="test_appo_issue",
        stop={
            "sampler_results/episode_reward_mean": 150,
            "timesteps_total": 200000,
        },
    ),
    param_space=config,
)
tuner.fit()

chkpt = "<HOME>/ray_results/test_appo_issue/folder_name/checkpoint_000000"
# From an algorithm checkpoint, Policy.from_checkpoint() returns a dict
# mapping policy IDs to Policy objects.
policies = Policy.from_checkpoint(chkpt)

env = gym.make("CartPole-v1")
obs, info = env.reset()
terminated = truncated = False
step = 0
while not terminated and not truncated:
    step += 1
    # Use local_policy_inference() to run inference, so we do not have to
    # provide policy states or extra fetch dictionaries.
    # "env_1" and "agent_1" are dummy env and agent IDs to run connectors with.
    policy_outputs = local_policy_inference(
        policies["default_policy"], "env_1", "agent_1", obs, explore=False
    )
    assert len(policy_outputs) == 1
    action, _, _ = policy_outputs[0]
    print(f"step {step}", obs, action)
    # Step the environment forward one more step.
    obs, _, terminated, truncated, _ = env.step(action)
```
This should at least use the connectors.
Hi @simonsays1980, has there been any progress in fixing this or should I use the example you shared?
If the issue is not yet fixed, I personally favor using `policy.compute_single_action(...)`:

```python
action = policy.compute_single_action(obs, explore=False, prev_action=xx, prev_reward=yy)  # ...
```

because the method has a `prev_action` param, which is useful in my case where `ViewRequirements` requires this information. Will connectors be applied in this setup?
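For context, a minimal sketch of the loop this would amount to (the checkpoint path and `default_policy` key are placeholders carried over from the snippet above):

```python
import gymnasium as gym

from ray.rllib.policy.policy import Policy

# Placeholder checkpoint path, as in the snippet above.
chkpt = "<HOME>/ray_results/test_appo_issue/folder_name/checkpoint_000000"
policy = Policy.from_checkpoint(chkpt)["default_policy"]

env = gym.make("CartPole-v1")
obs, info = env.reset()
prev_action, prev_reward = None, 0.0
terminated = truncated = False
while not terminated and not truncated:
    # compute_single_action() returns (action, state_outs, extra_fetches).
    # Whether the agent connectors (and thus the MeanStdFilter) run here is
    # exactly the open question of this issue.
    action, _, _ = policy.compute_single_action(
        obs, explore=False, prev_action=prev_action, prev_reward=prev_reward
    )
    obs, reward, terminated, truncated, _ = env.step(action)
    prev_action, prev_reward = action, reward
```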
Hi @simonsays1980, I've tried to use the `local_policy_inference` method, but it raises an exception on the second step. I think the problem is caused here:

It seems that on the first call, missing data (e.g. `"rewards"`) is padded. However, on the second call it is simply `[[]]`, which causes an IndexError when doing this caching.
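To illustrate, a minimal two-call sketch of what I'm running (placeholder checkpoint path again); the second `local_policy_inference` call is where the IndexError surfaces for me:

```python
import gymnasium as gym

from ray.rllib.policy.policy import Policy
from ray.rllib.utils.policy import local_policy_inference

# Placeholder checkpoint path, as in the snippets above.
chkpt = "<HOME>/ray_results/test_appo_issue/folder_name/checkpoint_000000"
policy = Policy.from_checkpoint(chkpt)["default_policy"]

env = gym.make("CartPole-v1")
obs, info = env.reset()

# First call works: missing fields (e.g. "rewards") get padded.
outputs = local_policy_inference(policy, "env_1", "agent_1", obs, explore=False)
action, _, _ = outputs[0]
obs, reward, terminated, truncated, info = env.step(action)

# Second call raises an IndexError for me when the RLModule stack is enabled.
outputs = local_policy_inference(policy, "env_1", "agent_1", obs, explore=False)
```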
UPDATE: I found out this is likely caused by `rl_module`. If I force PPO to use ModelV2 (`PPOConfig().rl_module(_enable_rl_module_api=False).training(_enable_learner_api=False)`), the code runs without a problem. However, since I have already trained the model, is there a way to fix it?
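For reference, a sketch of the full workaround config I mean (flag names as of Ray 2.7; both flags are internal and may change):

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment(env="CartPole-v1")
    .rollouts(observation_filter="MeanStdFilter", enable_connectors=True)
    # Fall back to the old ModelV2 stack: with the RLModule and Learner APIs
    # disabled, local_policy_inference() runs without the IndexError.
    .rl_module(_enable_rl_module_api=False)
    .training(_enable_learner_api=False)
)
```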
What happened + What you expected to happen
As described originally in a comment to an existing issue, a model trained with `MeanStdFilter` does not apply that filtering logic during inference via `algo.compute_single_action`. I have encountered the issue with both `APPO` and `TD3`; it is likely to affect other algorithms as well. For instance, the following does not reproduce the results seen during training when the model is reused for inference:
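(The original snippet is not reproduced here; below is a sketch of the naive inference pattern in question, with a placeholder checkpoint path.)

```python
import gymnasium as gym

from ray.rllib.algorithms.algorithm import Algorithm

# Placeholder checkpoint path.
algo = Algorithm.from_checkpoint("<HOME>/ray_results/test_appo_issue/folder_name/checkpoint_000000")

env = gym.make("CartPole-v1")
obs, info = env.reset()
# The raw observation goes in unfiltered: the MeanStdFilter statistics
# learned during training are never applied here.
action = algo.compute_single_action(obs, explore=False)
```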
In order to reproduce training results, the connector needs to be called explicitly:
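(Again a sketch rather than the original snippet: explicitly pushing the observation through the policy's agent connectors, which hold the MeanStdFilter, mirroring what `local_policy_inference()` does internally in Ray 2.7; `env_1`/`agent_1` are dummy IDs.)

```python
import gymnasium as gym

from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentConnectorDataType

# Placeholder checkpoint path.
chkpt = "<HOME>/ray_results/test_appo_issue/folder_name/checkpoint_000000"
policy = Policy.from_checkpoint(chkpt)["default_policy"]

env = gym.make("CartPole-v1")
obs, info = env.reset()

# Run the agent connector pipeline (including the MeanStdFilter) on the raw
# observation, then compute the action from the resulting sample batch.
acd = AgentConnectorDataType("env_1", "agent_1", {SampleBatch.NEXT_OBS: obs})
ac_outputs = policy.agent_connectors([acd])
actions, _, _ = policy.compute_actions_from_input_dict(ac_outputs[0].data.sample_batch)
action = actions[0]
```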
I would have expected the filtering logic to be applied automatically instead. I am not sure whether this only affects the MeanStd connector or whether others are affected as well.
Versions / Dependencies
Ray: 2.7.0
Python: 3.10.11
OS: Ubuntu 22.04
Reproduction script
Snippets included in the previous section.
Issue Severity
Medium: It is a significant difficulty but I can work around it.