opendilab / DI-engine

OpenDILab Decision AI Engine. The Most Comprehensive Reinforcement Learning Framework B.P.
https://di-engine-docs.readthedocs.io
Apache License 2.0

feature(nyz): adapt DingEnvWrapper to gymnasium #817

Closed: PaParaZz1 closed this 3 months ago

PaParaZz1 commented 3 months ago

Description

Examples

import gymnasium as gym
from ditk import logging
from ding.model import DQN
from ding.policy import DQNPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
    eps_greedy_handler, CkptSaver, nstep_reward_enhancer, final_ctx_saver
from ding.utils import set_pkg_seed
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config

def main():
    logging.getLogger().setLevel(logging.INFO)
    main_config.exp_name = 'cartpole_dqn_nstep_gymnasium'
    main_config.policy.nstep = 3
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        collector_env = BaseEnvManagerV2(
            env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
            cfg=cfg.env.manager
        )
        evaluator_env = BaseEnvManagerV2(
            env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
            cfg=cfg.env.manager
        )

        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        model = DQN(**cfg.policy.model)
        buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
        policy = DQNPolicy(cfg.policy, model=model)

        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(eps_greedy_handler(cfg))
        task.use(StepCollector(cfg, policy.collect_mode, collector_env))
        task.use(nstep_reward_enhancer(cfg))
        task.use(data_pusher(cfg, buffer_))
        task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
        task.use(final_ctx_saver(cfg.exp_name))
        task.run()

if __name__ == "__main__":
    main()
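The example above passes a `gymnasium` environment straight to `DingEnvWrapper`. For readers unfamiliar with why an adapter is needed: `gymnasium` returns `(obs, info)` from `reset()` and a 5-tuple `(obs, reward, terminated, truncated, info)` from `step()`, whereas classic `gym` returned a bare observation and a 4-tuple. The following is a minimal sketch of how the two conventions could be normalized. The helper names `normalize_reset` and `normalize_step` are hypothetical and are not taken from this PR or from DI-engine; the actual change lives in `ding/envs/env/ding_env_wrapper.py` and may differ.

```python
from typing import Any, Dict, Tuple


def normalize_reset(result: Any) -> Tuple[Any, Dict]:
    """Return (obs, info) for either reset convention (hypothetical helper)."""
    # gymnasium: reset() -> (obs, info); classic gym: reset() -> obs.
    # Heuristic: a 2-tuple whose second item is a dict is treated as (obs, info).
    # An observation that is itself such a tuple would be misread.
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0], result[1]
    return result, {}


def normalize_step(result: Tuple) -> Tuple[Any, float, bool, Dict]:
    """Return (obs, reward, done, info) for either step convention (hypothetical helper)."""
    if len(result) == 5:
        # gymnasium: the episode ends when it is terminated OR truncated.
        obs, reward, terminated, truncated, info = result
        return obs, reward, bool(terminated or truncated), info
    if len(result) == 4:
        # classic gym: a single done flag.
        obs, reward, done, info = result
        return obs, reward, bool(done), info
    raise ValueError(f"unexpected step() result of length {len(result)}")
```

Collapsing `terminated` and `truncated` into one `done` flag is the simplest choice for a sketch; code that bootstraps value estimates on time-limit truncation would need to keep the two flags separate.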

Check List

codecov[bot] commented 3 months ago

Codecov Report

Attention: Patch coverage is 97.22222% with 1 line in your changes missing coverage. Please review.

Project coverage is 75.99%. Comparing base (7f95159) to head (701152e).

| Files | Patch % | Lines |
|---|---|---|
| ding/envs/env/ding_env_wrapper.py | 95.00% | 1 Missing :warning: |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #817      +/-   ##
==========================================
+ Coverage   75.94%   75.99%   +0.04%
==========================================
  Files         684      684
  Lines       55597    55607      +10
==========================================
+ Hits        42224    42256      +32
+ Misses      13373    13351      -22
```

| [Flag](https://app.codecov.io/gh/opendilab/DI-engine/pull/817/flags?src=pr&el=flags&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=opendilab) | Coverage Δ | |
|---|---|---|
| [unittests](https://app.codecov.io/gh/opendilab/DI-engine/pull/817/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=opendilab) | `75.99% <97.22%> (+0.04%)` | :arrow_up: |

Flags with carried forward coverage won't be shown. [Click here](https://docs.codecov.io/docs/carryforward-flags?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=opendilab#carryforward-flags-in-the-pull-request-comment) to find out more.


kxzxvbk commented 3 months ago

I don't see any problems with this PR.