corl-team / xland-minigrid

JAX-accelerated Meta-Reinforcement Learning Environments Inspired by XLand and MiniGrid 🏎️

Fail to optimize single tasks in the given demo. Maybe intermediate reward is needed. #21

Closed: alexxchen closed this issue 3 weeks ago

alexxchen commented 1 month ago

I tried train_single_task.py and modified the config to:

    env_id: str = "XLand-MiniGrid-R1-9x9"
    benchmark_id: Optional[str] = "trivial-1m"
    ruleset_id: Optional[int] = 0
    num_envs: int = 32
    total_timesteps: int = 100_000

The final results are:

Compiling...
Done in 32.63s.
Training...
Done in 1308.51s
Logging...
Final return:  0.0
wandb:                                                                                
wandb: 
wandb: Run history:
wandb:   actor_loss ▆▆█▆▆▄▅▅▄▅▄▁▄▄▄▃▃▃▃▆▄▄▃▄▂▃▄▃▃▃▄▃▃▄▅▄▄▃▃▃
wandb:      entropy █▇▆▅▅▆▅▅▄▃▁▂▂▂▁▁▁▁▁▁▂▁▂▂▂▂▁▁▁▁▁▁▁▁▂▂▂▂▂▂
wandb: eval/lengths ▇███▄█▇▅▁███▅█▄▃█████████████████▃█▃████
wandb: eval/returns ▂▂▁▁▅▁▃▄█▁▁▁▄▁▅▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▁▆▁▁▁▁
wandb:           lr ███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
wandb:   total_loss ▄▄▇▄▄▁▄▄▃▆▆▁▅▆▅▄▅▃▄█▅▅▃▄▂▄▅▄▅▅▆▄▄▅▆▅▅▄▄▄
wandb:  transitions ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb:   value_loss █▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: 
wandb: Run summary:
wandb:       actor_loss 9e-05
wandb:          entropy 0.3547
wandb:     eval/lengths 243.0
wandb:     eval/returns 0.0
wandb:               lr 1e-05
wandb: steps_per_second 76.42288
wandb:       total_loss -0.00346
wandb:    training_time 1308.5086
wandb:      transitions 99840
wandb:       value_loss 0.0

It seems that ruleset_id=0 is difficult to learn compared to ruleset_id=1 (which reaches a return of 0.46 with fewer total_timesteps). The same holds for ruleset_id=3. I found that ruleset_id=0 and ruleset_id=3 share the same TileNearRightGoal task. I guess the tile task needs more operations than the others, since it is effectively a composition of AgentHoldGoal, AgentNearGoal and AgentOnTileGoal? I don't know whether it is possible to add an intermediate reward for the tile tasks.
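For illustration, here is a minimal sketch of what such shaping could look like, wrapping the jitted step function. This is not part of the library: shaping_fn is a hypothetical placeholder, and TimeStep.replace is assumed to behave like the env_params.replace used in the eval script later in this thread.

# Sketch: wrap the environment's step function and add a shaping bonus
# on top of the sparse task reward. shaping_fn is a hypothetical placeholder
# that would compute e.g. a proximity bonus from the next timestep.
import jax

def make_shaped_step(step_fn, shaping_fn, coef=0.01):
    def shaped_step(env_params, timestep, action):
        next_timestep = step_fn(env_params, timestep, action)
        bonus = coef * shaping_fn(next_timestep)
        # TimeStep is assumed to support .replace(), like env_params does
        return next_timestep.replace(reward=next_timestep.reward + bonus)
    return shaped_step

# usage (sketch): shaped_step_fn = jax.jit(make_shaped_step(env.step, my_shaping_fn))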

Howuhh commented 1 month ago

Hi @alexxchen. It is quite normal!

Not all problems are solvable by the baselines currently provided; otherwise there would be nothing new to explore! So one should not expect current methods to solve absolutely every task. Even with more sophisticated methods (not public yet), I only manage to solve a small part, due to the very sparse reward and the complexity of exploration.

Other than that, all trivial tasks should be solvable, but you need to train for at least 500M transitions (try increasing num_envs to 8k or higher). You can provide intermediate rewards, but it is up to you to decide what they should be; the original environments will keep their sparse rewards, as that is closer to real applications.

Howuhh commented 1 month ago

For example, something along these lines should solve the majority of trivial tasks:

env_id: "XLand-MiniGrid-R1-9x9"
benchmark_id: "trivial-1m"
ruleset_id: 0
total_timesteps: 500_000_000
num_envs: 8192
num_steps: 256
num_minibatches: 8
gae_lambda: 0.999
gamma: 0.999

It should take approximately 5 minutes on an A100 GPU for a single run.
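For reference, translated into the config fields edited in train_single_task.py at the top of this issue, that would look roughly like the following (field names follow the earlier snippet; the types of the fields not shown there are assumptions):

    env_id: str = "XLand-MiniGrid-R1-9x9"
    benchmark_id: Optional[str] = "trivial-1m"
    ruleset_id: Optional[int] = 0
    total_timesteps: int = 500_000_000
    num_envs: int = 8192
    num_steps: int = 256
    num_minibatches: int = 8
    gae_lambda: float = 0.999
    gamma: float = 0.999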

alexxchen commented 1 month ago

@Howuhh Thanks for your reply! I tried increasing num_envs and saw a slight increase in the final return.

Final return:  0.023287037387490273
wandb:                                                                                
wandb: 
wandb: Run history:
wandb:   actor_loss ▁▃▂▇▆▆█▆▇
wandb:      entropy █▅▄▄▄▃▂▁▁
wandb: eval/lengths ▆▅▅▂▆▁█▄▁
wandb: eval/returns ▄▄▅█▃█▁▆█
wandb:           lr █▇▆▅▅▄▃▂▁
wandb:   total_loss █▁▁▁▁▁▁▁▁
wandb:  transitions ▁▂▃▄▅▅▆▇█
wandb:   value_loss █▁▁▁▁▁▁▁▁
wandb: 
wandb: Run summary:
wandb:       actor_loss -2e-05
wandb:          entropy 1.74209
wandb:     eval/lengths 237.3875
wandb:     eval/returns 0.02329
wandb:               lr 0.00011
wandb: steps_per_second 2474.68162
wandb:       total_loss -0.01736
wandb:    training_time 4040.92386
wandb:      transitions 9437184
wandb:       value_loss 0.00016

But the success rate measured with eval.py is still poor (around 3%, almost chance level). Here are my modifications to eval.py to calculate the success rate:

import imageio
import jax
import jax.numpy as jnp
import orbax.checkpoint
import xminigrid
from nn import ActorCriticRNN
from xminigrid.rendering.text_render import print_ruleset
from xminigrid.wrappers import GymAutoResetWrapper

TOTAL_EPISODES = 100

def main():
    orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    checkpoint = orbax_checkpointer.restore("./training/checkpoints")
    config = checkpoint["config"]
    params = checkpoint["params"]

    env, env_params = xminigrid.make("XLand-MiniGrid-R1-9x9")
    env = GymAutoResetWrapper(env)

    ruleset = xminigrid.load_benchmark("trivial-1m").get_ruleset(0)
    env_params = env_params.replace(ruleset=ruleset)

    model = ActorCriticRNN(
        num_actions=env.num_actions(env_params),
        action_emb_dim=config["action_emb_dim"],
        rnn_hidden_dim=config["rnn_hidden_dim"],
        rnn_num_layers=config["rnn_num_layers"],
        head_hidden_dim=config["head_hidden_dim"],
    )
    # jitting all functions
    apply_fn, reset_fn, step_fn = jax.jit(model.apply), jax.jit(env.reset), jax.jit(env.step)

    # initial inputs
    prev_reward = jnp.asarray(0)
    prev_action = jnp.asarray(0)
    hidden = model.initialize_carry(1)

    # for logging
    total_reward, num_episodes, success = 0, 0, 0
    rendered_imgs = []

    rng = jax.random.PRNGKey(0)
    rng, _rng = jax.random.split(rng)

    timestep = reset_fn(env_params, _rng)
    rendered_imgs.append(env.render(env_params, timestep))
    while num_episodes < TOTAL_EPISODES:
        rng, _rng = jax.random.split(rng)
        dist, value, hidden = apply_fn(
            params,
            {
                "observation": timestep.observation[None, None, ...],
                "prev_action": prev_action[None, None, ...],
                "prev_reward": prev_reward[None, None, ...],
            },
            hidden,
        )
        action = dist.sample(seed=_rng).squeeze()

        timestep = step_fn(env_params, timestep, action)
        prev_action = action
        prev_reward = timestep.reward

        # count a step with positive reward as a success; the episode ends when timestep.last() is True
        success += int(timestep.reward.item() > 0)
        total_reward += timestep.reward.item()
        num_episodes += int(timestep.last().item())

        rendered_imgs.append(env.render(env_params, timestep))

    print("Total reward:", total_reward)
    print("Success Rate:", success / TOTAL_EPISODES)
    print_ruleset(ruleset)
    # imageio.mimsave("rollout.mp4", rendered_imgs, fps=8, format="mp4")
    # imageio.mimsave("rollout.gif", rendered_imgs, duration=1000 * 1 / 8, format="gif")

if __name__ == "__main__":
    main()
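As a side note on the script above, a slightly more robust sketch is to count at most one success per episode, accumulating the reward until timestep.last() fires. Using the same variables as in the loop above, with episode_reward initialized to 0.0 before the loop:

    # inside the while loop, after step_fn(...)
    episode_reward += timestep.reward.item()
    if timestep.last().item():
        num_episodes += 1
        success += int(episode_reward > 0)  # at most one success per episode
        total_reward += episode_reward
        episode_reward = 0.0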

The cluster I use does not yet support the newer CUDA version required by xland-minigrid. Thus the result above is from:

num_envs: int = 8192
num_steps: int = 128
total_timesteps: int = 10_000_000

I would like to know what success rate you get in your test runs. Thanks a lot!

Howuhh commented 1 month ago

@alexxchen this is strange! I will re-run it (approximately tomorrow) and release the wandb logs + config.

alexxchen commented 1 month ago

How is it going?

Howuhh commented 1 month ago

@alexxchen I re-trained on trivial-1m, ruleset_id=0 with the config above: https://wandb.ai/state-machine/xminigrid/groups/trivial-tmp/workspace?nw=nwuserstatemachine

It seems normal: it easily solves this ruleset within 1B transitions and took only 12 minutes. With some tuning I think it will solve it within 500M transitions too.

P.S. eval.py is provided as an example for meta-RL policies, not for single tasks, since evaluation for those is straightforward (as in any other method). You can use rollout from utils.py.

Howuhh commented 1 month ago

@alexxchen has this solved your problem?

alexxchen commented 1 month ago

> @alexxchen has this solved your problem?

Yes! Exactly! I'm just re-running the experiment on CPU to figure out what's the matter with eval.py. Thank you very much for your reply!

alexxchen commented 4 weeks ago

In my final result, the success rate is 100%. So the difficulty of learning task 0 (Goal 8) is due to its sparser reward compared with the other tasks, but it can be overcome with an extremely large number of timesteps.

Howuhh commented 3 weeks ago

@alexxchen good news! Good luck with your experiments