ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[rllib] LR not updated by LearningRateSchedule in TorchPolicy #12163

Closed: Manuscrit closed this issue 3 years ago

Manuscrit commented 3 years ago

What is the problem?

In a TorchPolicy (I did not test with TFPolicy) that uses the LearningRateSchedule mixin, the lr values in the optimizers' param_groups are never updated, so the learning rate stays constant. The cur_lr attribute has the right value, but the optimizers are not being updated to use it. This may be related or similar to #6101 and #10423. The optimizer method of the Policy is not overridden when using the LearningRateSchedule mixin. Tested with DQN and PPO (PyTorch only).
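To make the expected behavior concrete, here is a minimal, dependency-free sketch of what I would expect to happen: whenever cur_lr is recomputed from the schedule, the new value is also written into every optimizer's param_groups. This is only an illustration, not RLlib's actual code or the eventual fix. FakeOptimizer, ScheduledPolicy, and piecewise_lr are hypothetical names, FakeOptimizer merely mimics the param_groups layout of a torch optimizer, and the linear interpolation is my assumption about how the lr_schedule pairs are meant to be read.

```python
class FakeOptimizer:
    """Hypothetical stand-in for a torch optimizer; only mimics `param_groups`."""

    def __init__(self, lr):
        self.param_groups = [{"lr": lr}]


def piecewise_lr(schedule, t):
    """Linearly interpolate (timestep, lr) pairs; clamp outside the range."""
    if t <= schedule[0][0]:
        return schedule[0][1]
    for (t0, lr0), (t1, lr1) in zip(schedule, schedule[1:]):
        if t0 <= t < t1:
            frac = (t - t0) / (t1 - t0)
            return lr0 + frac * (lr1 - lr0)
    return schedule[-1][1]


class ScheduledPolicy:
    """Hypothetical policy: keeps cur_lr and the optimizers in sync."""

    def __init__(self, lr, lr_schedule):
        self.cur_lr = lr
        self.lr_schedule = lr_schedule
        self._optimizers = [FakeOptimizer(lr)]

    def on_global_var_update(self, global_vars):
        # Recompute the scheduled LR from the global timestep.
        self.cur_lr = piecewise_lr(self.lr_schedule, global_vars["timestep"])
        # The step this issue reports as missing: push the new value
        # into every param group of every optimizer.
        for opt in self._optimizers:
            for group in opt.param_groups:
                group["lr"] = self.cur_lr


policy = ScheduledPolicy(1e-3, [(0, 1e-3), (20000, 1e-5)])
policy.on_global_var_update({"timestep": 10000})
print(policy.cur_lr, [g["lr"] for g in policy._optimizers[0].param_groups])
```

With this in place, the assertion in the reproduction script below (p["lr"] == policy.cur_lr for every param group) would hold after each update.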

Ray version and other system information (Python version, TensorFlow version, OS): Ray 1.0.0 and nightly, Python 3.8.5, torch 1.7.0

Reproduction (REQUIRED)

Please provide a script that can be run to reproduce the issue. The script should have no external library dependencies (i.e., use fake or mock data / environments):

import argparse

import ray
from ray import tune
from ray.rllib.examples.env.parametric_actions_cartpole import \
    ParametricActionsCartPole
from ray.rllib.examples.models.parametric_actions_model import \
    ParametricActionsModel, TorchParametricActionsModel
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune.registry import register_env
from ray.rllib.agents.callbacks import DefaultCallbacks

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--torch", action="store_true")
parser.add_argument("--as-test", action="store_true")
parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument("--stop-reward", type=float, default=150.0)
parser.add_argument("--stop-timesteps", type=int, default=100000)

class MyCallbacks(DefaultCallbacks):

    def on_train_result(self, *, trainer, result: dict, **kwargs):
        # After each training iteration, compare the scheduled LR (cur_lr)
        # with the LR actually stored in each torch optimizer.
        def print_true_lr(policy, policy_id):
            print("policy.cur_lr", policy.cur_lr)
            for j, opt in enumerate(policy._optimizers):
                print(policy_id, j, [p["lr"] for p in opt.param_groups])
                for p in opt.param_groups:
                    assert p["lr"] == policy.cur_lr, (
                        f'should have p["lr"] == policy.cur_lr but got '
                        f'p["lr"] = {p["lr"]} and '
                        f'policy.cur_lr = {policy.cur_lr}')
        trainer.workers.foreach_policy(print_true_lr)

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    register_env("pa_cartpole", lambda _: ParametricActionsCartPole(10))
    ModelCatalog.register_custom_model(
        "pa_model", TorchParametricActionsModel
        if args.torch else ParametricActionsModel)

    if args.run == "DQN":
        cfg = {
            # TODO(ekl) we need to set these to prevent the masked values
            # from being further processed in DistributionalQModel, which
            # would mess up the masking. It is possible to support these if we
            # defined a custom DistributionalQModel that is aware of masking.
            "hiddens": [],
            "dueling": False,
        }
    else:
        cfg = {}

    config = dict({
        "env": "pa_cartpole",
        "model": {
            "custom_model": "pa_model",
        },
        "num_workers": 0,
        "framework": "torch" if args.torch else "tf",

        "callbacks": MyCallbacks,
        # === Optimization ===
        # Learning rate for adam optimizer
        "lr": 1e-3,
        # Learning rate schedule
        "lr_schedule": [(0, 1e-3), (20000, 3e-5 / 1e9)],

    }, **cfg)

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run(args.run, stop=stop, config=config, verbose=1)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)

    ray.shutdown()

If we cannot run your script, we cannot fix your issue.

sven1977 commented 3 years ago

Thanks for filing this issue, @Manuscrit. PR: https://github.com/ray-project/ray/pull/12396