I tried train_single_task.py and modified the config to:
The final results are
It seems that ruleset_id=0 is difficult to learn compared to ruleset_id=1 (which reaches a return of 0.46 with fewer total_timesteps). The same holds for ruleset_id=3. I found that ruleset_id=0 and ruleset_id=3 share the same TileNearRightGoal task. I guess the tile task needs more operations than the others; is it composed of AgentHoldGoal, AgentNearGoal, and AgentOnTileGoal? I don't know whether it is possible to add an intermediate reward for the tile tasks.
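To see what each of these rulesets actually asks for, they can be printed with xminigrid's benchmark helpers (load_benchmark, get_ruleset, print_ruleset, the same calls used in eval.py further down); a minimal sketch for the indices compared above:

import xminigrid
from xminigrid.rendering.text_render import print_ruleset

# print the goal/rule text for the rulesets compared in this issue
benchmark = xminigrid.load_benchmark("trivial-1m")
for ruleset_id in (0, 1, 3):
    print(f"--- ruleset {ruleset_id} ---")
    print_ruleset(benchmark.get_ruleset(ruleset_id))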
Hi @alexxchen. It is quite normal!
Not all problems are solvable by the current baselines provided, otherwise it would be boring to explore something new! So one should not expect current methods to solve absolutely all problems. Even I, with more sophisticated methods (not public yet), only manage to solve a small part, due to the very sparse reward and the complexity of the exploration.
Other than that, all trivial tasks should be solvable, but you need to train for at least 500M transitions (try increasing num_envs to 8k or higher). You can provide intermediate rewards, but it is up to you to decide what they should be, as the original environments will remain sparse-reward (as that is closer to real applications). A minimal shaping sketch follows the example config below.
For example, something along these lines should solve the majority of trivial tasks:
env_id: "XLand-MiniGrid-R1-9x9"
benchmark_id: "trivial-1m"
ruleset_id: 0
total_timesteps: 500_000_000
num_envs: 8192
num_steps: 256
num_minibatches: 8
gae_lambda: 0.999
gamma: 0.999
Should take approx. 5 minutes on an A100 GPU for a single run.
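As for intermediate rewards, here is a minimal, purely illustrative sketch of one way to shape them: agent_goal_distance is a hypothetical helper (not part of xland-minigrid) that you would have to implement yourself from timestep.state, and the 0.01 coefficient is arbitrary; only env.step and .replace mirror calls that actually appear elsewhere in this thread.

def agent_goal_distance(timestep):
    # hypothetical helper: derive agent and goal positions from timestep.state
    raise NotImplementedError

def shaped_step(env, env_params, timestep, action, coef=0.01):
    prev_dist = agent_goal_distance(timestep)
    new_timestep = env.step(env_params, timestep, action)  # original sparse reward
    # small dense bonus for moving closer to the goal; the sparse reward is kept
    bonus = coef * (prev_dist - agent_goal_distance(new_timestep))
    return new_timestep.replace(reward=new_timestep.reward + bonus)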
@Howuhh Thanks for your reply! I tried increasing num_envs and saw a slight increase in the final return.
Final return: 0.023287037387490273
wandb:
wandb:
wandb: Run history:
wandb: actor_loss ▁▃▂▇▆▆█▆▇
wandb: entropy █▅▄▄▄▃▂▁▁
wandb: eval/lengths ▆▅▅▂▆▁█▄▁
wandb: eval/returns ▄▄▅█▃█▁▆█
wandb: lr █▇▆▅▅▄▃▂▁
wandb: total_loss █▁▁▁▁▁▁▁▁
wandb: transitions ▁▂▃▄▅▅▆▇█
wandb: value_loss █▁▁▁▁▁▁▁▁
wandb:
wandb: Run summary:
wandb: actor_loss -2e-05
wandb: entropy 1.74209
wandb: eval/lengths 237.3875
wandb: eval/returns 0.02329
wandb: lr 0.00011
wandb: steps_per_second 2474.68162
wandb: total_loss -0.01736
wandb: training_time 4040.92386
wandb: transitions 9437184
wandb: value_loss 0.00016
But the success rate is still poor (around 3%, almost chance level) when evaluating with eval.py. Here are my modifications to eval.py to compute the success rate:
import imageio
import jax
import jax.numpy as jnp
import orbax.checkpoint
import xminigrid
from nn import ActorCriticRNN
from xminigrid.rendering.text_render import print_ruleset
from xminigrid.wrappers import GymAutoResetWrapper

TOTAL_EPISODES = 100


def main():
    orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    checkpoint = orbax_checkpointer.restore("./training/checkpoints")
    config = checkpoint["config"]
    params = checkpoint["params"]

    env, env_params = xminigrid.make("XLand-MiniGrid-R1-9x9")
    env = GymAutoResetWrapper(env)

    ruleset = xminigrid.load_benchmark("trivial-1m").get_ruleset(0)
    env_params = env_params.replace(ruleset=ruleset)

    model = ActorCriticRNN(
        num_actions=env.num_actions(env_params),
        action_emb_dim=config["action_emb_dim"],
        rnn_hidden_dim=config["rnn_hidden_dim"],
        rnn_num_layers=config["rnn_num_layers"],
        head_hidden_dim=config["head_hidden_dim"],
    )
    # jitting all functions
    apply_fn, reset_fn, step_fn = jax.jit(model.apply), jax.jit(env.reset), jax.jit(env.step)

    # initial inputs
    prev_reward = jnp.asarray(0)
    prev_action = jnp.asarray(0)
    hidden = model.initialize_carry(1)

    # for logging
    total_reward, num_episodes, success = 0, 0, 0
    rendered_imgs = []

    rng = jax.random.PRNGKey(0)
    rng, _rng = jax.random.split(rng)

    timestep = reset_fn(env_params, _rng)
    rendered_imgs.append(env.render(env_params, timestep))

    while num_episodes < TOTAL_EPISODES:
        rng, _rng = jax.random.split(rng)

        dist, value, hidden = apply_fn(
            params,
            {
                "observation": timestep.observation[None, None, ...],
                "prev_action": prev_action[None, None, ...],
                "prev_reward": prev_reward[None, None, ...],
            },
            hidden,
        )
        action = dist.sample(seed=_rng).squeeze()

        timestep = step_fn(env_params, timestep, action)
        prev_action = action
        prev_reward = timestep.reward

        success += int(timestep.reward.item() > 0)
        total_reward += timestep.reward.item()
        num_episodes += int(timestep.last().item())
        rendered_imgs.append(env.render(env_params, timestep))

    print("Total reward:", total_reward)
    print("Success Rate:", success / TOTAL_EPISODES)
    print_ruleset(ruleset)
    # imageio.mimsave("rollout.mp4", rendered_imgs, fps=8, format="mp4")
    # imageio.mimsave("rollout.gif", rendered_imgs, duration=1000 * 1 / 8, format="gif")


if __name__ == "__main__":
    main()
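One caveat about the success counter above: it counts any step with positive reward as a success, which implicitly assumes reward is emitted only once per episode, on goal completion. An equivalent but more explicit variant, reusing the timestep.last() check the loop already performs:

# count a success only on terminal steps that ended with positive reward
if timestep.last().item() and timestep.reward.item() > 0:
    success += 1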
The cluster I use does not yet support the new CUDA version required by xland-minigrid, so the result above is from the following config (only about 10M transitions, far below the 500M suggested above):
num_envs: int = 8192
num_steps: int = 128
total_timesteps: int = 10_000_000
What success rate do you get in your test runs? Thanks a lot!
@alexxchen this is strange! I will re-run it (approx. tomorrow) and release the wandb logs + config.
How is it going?
@alexxchen re-trained on trivial-1m, ruleset-id=0 with the config above: https://wandb.ai/state-machine/xminigrid/groups/trivial-tmp/workspace?nw=nwuserstatemachine
Seems normal: it easily solves this ruleset within 1B transitions and took only 12 minutes. With some tuning, I think it will solve it in 500M transitions too.
P.S. eval.py is provided as an example for meta-RL policies, not for single-task ones, as evaluation for those is straightforward (as in any other method). You can use rollout from utils.py.
@alexxchen has this solved your problem?
Yes! Exactly! I'm just re-running the experiment on CPU to figure out what's the matter with eval.py. Thank you very much for your reply!
In my final result, the success rate is 100%. So the difficulty of learning task-0 (Goal 8) comes from its sparser reward compared with the other tasks, but it can be overcome with an extremely large number of timesteps.
@alexxchen good news! Good luck with your experiments
I tried train_single_task.py and modify the config into:
The final result are
It seems that the ruleset_id=0 is difficult to learn compared to ruleset_id=1 (return with less total_timesteps is 0.46). The same situation for ruleset_id=3. I found out that ruleset_id=0 and ruleset_id=3 share the same TileNearRightGoal task. I guess the tile task needs more operations than others, and it is comprised of AgentHoldGoal, AgentNearGoal and AgentOnTileGoal? I don't know if it is possible to add intermediate reward for the tile tasks.