rlworkgroup / garage

A toolkit for reproducible reinforcement learning research.
MIT License
1.84k stars 309 forks

How do I add a new meta environment? #2291

Closed tianyma closed 3 years ago

tianyma commented 3 years ago

Dear author, thank you for your excellent work. Could you tell me how to add a new environment for meta learning? Is anything special required, and could you post an example in the docs?

krzentner commented 3 years ago

Hi tianyma,

How to add a new meta environment depends on the API that environment uses. I recommend you look at the existing examples showing how to use the meta HalfCheetah environments and Meta-World benchmarks.

Generally speaking, you'll want to create a TaskSampler to use a meta environment, although TaskSamplers already exist for commonly used meta-RL APIs, like the set_task API.
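As a rough sketch of the set_task API mentioned above (using a stand-in class, not garage's actual implementation): a meta environment exposes sample_tasks and set_task, and a task sampler drives it roughly like this:

```python
import random


class ToyMetaEnv:
    """Stand-in for a meta environment following the set_task API."""

    def __init__(self):
        self._goal = 0.0

    def sample_tasks(self, num_tasks):
        # Each task is a dict describing one variation of the environment.
        return [{'goal': random.uniform(-1.0, 1.0)} for _ in range(num_tasks)]

    def set_task(self, task):
        # Reconfigure this environment instance for the given task.
        self._goal = task['goal']


# A task sampler only needs sample_tasks/set_task to produce
# per-task environment updates, roughly like this:
env = ToyMetaEnv()
tasks = env.sample_tasks(5)
for task in tasks:
    env.set_task(task)
    # ...collect episodes for this task...
```

The real SetTaskSampler handles constructing and updating environment copies across workers; this only illustrates the two methods an environment must provide.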

tianyma commented 3 years ago

Sorry, I am quite new here. If I use the 2D navigation environment PointEnv, should I add a wrapper such as GymEnv when constructing the env, or can I construct the environment directly?

ryanjulian commented 3 years ago

Hi @tianyma -- PointEnv inherits directly from Environment, so no wrapper should be necessary.
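To illustrate the distinction with toy stand-in classes (garage's real Environment and GymEnv interfaces are richer than this): an env written against a foreign API needs an adapter, while one that already subclasses the native interface is used as-is:

```python
class Environment:
    """Stand-in for garage's native environment interface."""

    def reset(self):
        raise NotImplementedError


class ForeignGymStyleEnv:
    """Stand-in for a gym-style env with a different reset signature."""

    def reset(self):
        return [0.0], {}  # gym-style (observation, info) pair


class GymEnvAdapter(Environment):
    """Stand-in wrapper (in the spirit of GymEnv) adapting the foreign API."""

    def __init__(self, inner):
        self._inner = inner

    def reset(self):
        obs, _info = self._inner.reset()
        return obs


class PointLikeEnv(Environment):
    """Stand-in for an env that, like PointEnv, subclasses Environment."""

    def reset(self):
        return [0.0, 0.0]


# The gym-style env needs wrapping; the native one does not.
wrapped = GymEnvAdapter(ForeignGymStyleEnv())
native = PointLikeEnv()
```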

tianyma commented 3 years ago

> Hi @tianyma -- PointEnv inherits directly from Environment, so no wrapper should be necessary.

Hi @ryanjulian, thank you very much for your reply. I modified the file maml_trpo_half_cheetah_dir.py to adapt it to the point environment. First I changed the following code:

    env = normalize(GymEnv(HalfCheetahDirEnv(),
                           max_episode_length=max_episode_length),
                    expected_action_scale=10.)

to

    goal = np.array([1., 1.])
    env = PointEnv(goal=goal, max_episode_length=max_episode_length)

and then I modify the task_sampler from

    task_sampler = SetTaskSampler(
        HalfCheetahDirEnv,
        wrapper=lambda env, _: normalize(GymEnv(
            env, max_episode_length=max_episode_length),
                                         expected_action_scale=10.))

to

    task_sampler = SetTaskSampler(PointEnv)

Then I ran the code directly, but I cannot seem to get any results. Can you help me?
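One thing worth checking in a setup like the one above: if SetTaskSampler calls its constructor argument with no arguments (worth verifying against the garage source), then SetTaskSampler(PointEnv) would build copies without the goal and max_episode_length passed when constructing env by hand, so episodes may never terminate. A zero-argument callable such as functools.partial can carry those arguments through. Sketched here with a stand-in class, since this is only an illustration:

```python
from functools import partial


class FakePointEnv:
    """Stand-in for PointEnv: just records its constructor arguments."""

    def __init__(self, goal=(0.0, 0.0), max_episode_length=None):
        self.goal = goal
        self.max_episode_length = max_episode_length


def make_envs(env_constructor, n):
    """Stand-in for a task sampler: it only ever calls env_constructor()."""
    return [env_constructor() for _ in range(n)]


# Calling the class directly loses the keyword arguments...
bare = make_envs(FakePointEnv, 2)
# ...while a partial bakes them into every constructed copy.
configured = make_envs(
    partial(FakePointEnv, goal=(1.0, 1.0), max_episode_length=100), 2)
```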