opendilab / DI-engine

OpenDILab Decision AI Engine. The Most Comprehensive Reinforcement Learning Framework.
https://di-engine-docs.readthedocs.io
Apache License 2.0
3k stars 366 forks

TD3 raises AssertionError when used with a hybrid action space #789

Closed dajianer closed 5 months ago

dajianer commented 5 months ago

When training TD3 on a hybrid action space environment, the run fails with `assert isinstance(action, torch.Tensor)`. Looking at the source, the return value of `HybridArgmaxSampleWrapper.forward` can indeed trigger this error. How should I fix this? Code below:

    # Imports inferred from the snippet (ding is DI-engine); main_config,
    # create_config and DI_UAV_AoI come from the user's own files.
    import logging
    from ding.config import compile_config
    from ding.data import DequeBuffer
    from ding.envs import BaseEnvManagerV2
    from ding.framework import ding_init, task, OnlineRLContext
    from ding.framework.middleware import (
        CkptSaver, OffPolicyLearner, StepCollector, data_pusher,
        interaction_evaluator, online_logger, termination_checker
    )
    from ding.model import ContinuousQAC
    from ding.policy import TD3Policy
    from ding.utils import set_pkg_seed

    logging.getLogger().setLevel(logging.INFO)
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    ding_init(cfg)
    ctx = OnlineRLContext(collect_kwargs={'eps': 0.01})
    with task.start(async_mode=False, ctx=ctx):
        collector_env = BaseEnvManagerV2(
            env_fn=[lambda: DI_UAV_AoI(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
        )
        evaluator_env = BaseEnvManagerV2(
            env_fn=[lambda: DI_UAV_AoI(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
        )

        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        model = ContinuousQAC(**cfg.policy.model)
        buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
        policy = TD3Policy(cfg.policy, model=model)

        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(
            StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size)
        )
        task.use(data_pusher(cfg, buffer_))
        task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
        task.use(termination_checker(max_train_iter=int(1e5)))
        task.use(online_logger())
        task.run(max_step=int(1e5))
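For context on why the assertion fires: a hybrid action in DI-engine is typically a dict holding a discrete `action_type` and continuous `action_args`, while the noise machinery expects a plain tensor. A minimal self-contained sketch of the mismatch (the helper below is a simplified stand-in, not the actual DI-engine wrapper):

```python
import torch

# A hybrid action as produced by hybrid-action policies: a dict, not a tensor.
hybrid_action = {
    'action_type': torch.tensor([1]),          # discrete component
    'action_args': torch.tensor([0.3, -0.7]),  # continuous component
}


def add_gaussian_noise(action, sigma=0.1):
    # Simplified stand-in for the tensor check that fails in the wrapper.
    assert isinstance(action, torch.Tensor), type(action)
    return action + sigma * torch.randn_like(action)


# Plain tensors pass through fine; the dict-shaped hybrid action does not.
try:
    add_gaussian_noise(hybrid_action)
except AssertionError as exc:
    print('AssertionError:', exc)
```

This reproduces the failure mode in isolation: any code path that applies tensor-only noise to the dict-shaped hybrid action will trip the same assertion.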
PaParaZz1 commented 5 months ago

When using a hybrid action space, did you set `action_space='hybrid'` in TD3Policy?

The relevant implementation is here (link)

For a complete hybrid action space config file for DDPG/TD3-style algorithms, you can refer to this example
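For reference, the field in question lives under `policy` in a DI-engine config dict; a minimal sketch of the relevant part (the surrounding keys and shapes are illustrative, not copied from the linked example):

```python
# Sketch of the hybrid-action-related fields of a TD3 config.
# Shapes and key names besides action_space are illustrative assumptions.
hybrid_td3_config = dict(
    policy=dict(
        cuda=True,
        action_space='hybrid',  # required for hybrid action spaces
        model=dict(
            obs_shape=10,
            # Hybrid action spaces describe both components of the action.
            action_shape=dict(action_type_shape=4, action_args_shape=2),
            twin_critic=True,        # TD3 uses twin critics
            action_space='hybrid',   # the model must agree with the policy
        ),
    ),
)
```

The point is that `action_space='hybrid'` must be set consistently so the policy routes actions through the hybrid-aware code paths instead of the plain continuous ones.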

MarkHolmstrom commented 1 month ago

The issue appears to be TD3's target policy smoothing: the action noise wrapper does not support the hybrid action space. As a workaround, setting `noise=False` in the policy configuration, matching the reference DDPG config, disables target policy smoothing.
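Concretely, the workaround is a one-field change. A sketch of the relevant config fragment, assuming the noise settings sit under `policy.learn` as in DI-engine's TD3 defaults (other keys illustrative):

```python
# Workaround sketch: disable target policy smoothing so the tensor-only
# noise wrapper is never applied to the dict-shaped hybrid action.
policy_cfg = dict(
    action_space='hybrid',
    learn=dict(
        noise=False,      # off, matching the reference DDPG hybrid config
        noise_sigma=0.1,  # ignored once noise is False
    ),
)
```

The trade-off is losing target policy smoothing, one of TD3's three additions over DDPG, so with this change the algorithm behaves closer to DDPG with twin critics and delayed actor updates.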