ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0
33.09k stars 5.6k forks source link

[RLlib] Tuner.restore() Not Restoring Training #43266

Status: Open. GoodarzMehr opened this issue 7 months ago.

GoodarzMehr commented 7 months ago

What happened + What you expected to happen

I am using RLlib's SAC with a multi-agent environment that crashes from time to time due to memory issues. Resuming with `Tuner.restore()` (or, alternatively, with `Algorithm.save_checkpoint()` and `Algorithm.from_checkpoint()`) does not restore training; instead, training is re-initialized, as can be seen in the image below.

[Image not included in this copy: plot showing training being re-initialized after restoring.]

See here for further information (the original link was not preserved in this copy).

Versions / Dependencies

Ubuntu 22.04, Python 3.8.10, Ray 2.9.2, Torch 1.10.1 (cu113), CUDA 11.3

Reproduction script

I am using the following script:

import os
import ray
import yaml
import time
import argparse

from tensorboard import program

from ray import air, tune

from ray.tune.registry import register_env

from carla_env import CarlaEnv

# Command-line interface. It is parsed once, at import time, into the
# module-level ``args`` namespace that ``main()`` reads and updates.
argparser = argparse.ArgumentParser(description='CoPeRL Training Implementation.')

argparser.add_argument('config', help='configuration file')
argparser.add_argument(
    '-d', '--directory',
    metavar='D',
    default='/home/coperl/ray_results',
    help='directory to save the results (default: /home/coperl/ray_results)',
)
argparser.add_argument(
    '-n', '--name',
    metavar='N',
    default='sac_experiment',
    help='name of the experiment (default: sac_experiment)',
)
# ``store_true`` flags default to False on their own.
argparser.add_argument(
    '--restore',
    action='store_true',
    help='restore the specified experiment (default: False)',
)
argparser.add_argument(
    '--tb',
    action='store_true',
    help='activate tensorboard (default: False)',
)

args = argparser.parse_args()

def parse_config(args):
    '''
    Parse the YAML configuration file named on the command line.

    Args:
        args: command line arguments; only ``args.config`` (a file path)
            is read.

    Return:
        config: the parsed YAML document (a dictionary for a config file
            like the one used with this script).
    '''
    # ``safe_load`` builds only plain YAML types (mappings, lists, strings,
    # numbers, booleans). That is all an RLlib config needs, and unlike richer
    # loaders it cannot construct arbitrary Python objects from tags.
    with open(args.config, encoding='utf-8') as f:
        config = yaml.safe_load(f)

    return config

def launch_tensorboard(logdir, host='localhost', port='6006'):
    '''
    Launch TensorBoard for the given results directory.

    Args:
        logdir: directory of the saved results (str or path-like).
        host: host address.
        port: port number (str or int).

    Return:
        url: the URL returned by ``TensorBoard.launch()``.
    '''
    tb = program.TensorBoard()
    # TensorBoard parses ``argv`` like a command line, so every entry must be
    # a string; ``str()`` lets callers pass an int port or a Path logdir.
    tb.configure(argv=[None, '--logdir', str(logdir), '--host', host, '--port', str(port)])
    url = tb.launch()

    return url

def env_creator(env_config):
    '''
    Factory that builds one CarlaEnv instance.

    Args:
        env_config: configuration handed through to ``CarlaEnv``.

    Return:
        env: a freshly constructed CarlaEnv.
    '''
    env = CarlaEnv(env_config)
    return env

def run(args):
    '''
    Run Ray Tune with SAC, either starting a new experiment or restoring one.

    Args:
        args: command line arguments. Reads ``config`` (already parsed into
            a dictionary), ``directory``, ``name`` and ``restore``.

    Return:
        None. The best result is printed. Any exception raised while setting
        up or running the experiment is reported with its full traceback
        instead of propagating, and Ray is always shut down afterwards.
    '''
    import traceback

    try:
        ray.init(num_cpus=12, num_gpus=2)

        register_env('carla', env_creator)

        os.system('nvidia-smi')

        if not args.restore:
            tuner = tune.Tuner(
                'SAC',
                run_config=air.RunConfig(
                    name=args.name,
                    storage_path=args.directory,
                    checkpoint_config=air.CheckpointConfig(
                        num_to_keep=2,
                        checkpoint_frequency=1,
                        checkpoint_at_end=True
                    ),
                    stop={'training_iteration': 8192},
                    verbose=2
                ),
                param_space=args.config,
            )
        else:
            experiment_path = os.path.join(args.directory, args.name)
            # NOTE(review): whether SAC's replay buffer is part of the restored
            # checkpoint presumably depends on the algorithm config (e.g.
            # ``store_buffer_in_checkpoints``); an empty buffer after restore
            # could look like training restarting — confirm for this setup.
            tuner = tune.Tuner.restore(experiment_path, 'SAC', resume_errored=True)

        result = tuner.fit().get_best_result()

        print(result)

    except Exception:
        # Keep the best-effort behaviour (the script does not crash here), but
        # print the full traceback: ``print(e)`` alone hides where it failed.
        traceback.print_exc()
    finally:
        ray.shutdown()
        time.sleep(10.0)

def main():
    '''
    Entry point. Replace ``args.config`` (a file path) with the parsed
    configuration, optionally start TensorBoard on the experiment directory,
    then run training.
    '''
    cli = args
    cli.config = parse_config(cli)

    if cli.tb:
        experiment_dir = os.path.join(cli.directory, cli.name)
        launch_tensorboard(logdir=experiment_dir)

    run(cli)

if __name__ == '__main__':
    # Ctrl-C shuts Ray down explicitly; 'Done.' is printed however main()
    # exits (normally, via Ctrl-C, or with another exception).
    try:
        try:
            main()
        except KeyboardInterrupt:
            ray.shutdown()
    finally:
        print('Done.')

I am also using the following config file:

# RLlib SAC configuration; the script passes the parsed dictionary to
# tune.Tuner as ``param_space``.
framework: 'torch'

# 'carla' is the name the script registers with register_env().
env: 'carla'
disable_env_checking: True

# Resources.
num_workers: 1
num_gpus: 1
num_cpus_per_worker: 8
num_gpus_per_worker: 1

train_batch_size: 256

log_level: 'DEBUG'

# Fault tolerance.
ignore_worker_failures: True
restart_failed_sub_environments: False

# Checkpointing. NOTE(review): checkpoint_at_end is also set on the Tune
# CheckpointConfig in the script; confirm this algorithm-level key is needed.
checkpoint_at_end: True
export_native_model_files: True

keep_per_episode_custom_metrics: True

# Q-network model. Each conv filter is [out_channels, [kernel_h, kernel_w], stride].
q_model_config:
  fcnet_hiddens: [256, 256]
  dim: 200
  conv_filters:
    - [16, [3, 3], 2]
    - [32, [3, 3], 2]
    - [32, [3, 3], 2]
    - [64, [3, 3], 2]
    - [64, [3, 3], 2]
    - [128, [3, 3], 2]
  post_fcnet_hiddens: [256]

# Policy-network model (same layout as the Q-network).
policy_model_config:
  fcnet_hiddens: [256, 256]
  dim: 200
  conv_filters:
    - [16, [3, 3], 2]
    - [32, [3, 3], 2]
    - [32, [3, 3], 2]
    - [64, [3, 3], 2]
    - [64, [3, 3], 2]
    - [128, [3, 3], 2]
  post_fcnet_hiddens: [256]

CarlaEnv is a rather complex Gym environment, but basically it has the following action and observation spaces:

Box(low=-1.0, high=1.0)
Tuple((Box(low=-1.001, high=1.001, shape=(200, 200, 5)),
       Box(low=-1.001, high=1.001, shape=(12, 3))))

Issue Severity

Medium: It is a significant difficulty but I can work around it.

GoodarzMehr commented 7 months ago

While debugging, I noticed that the native model saved to a checkpoint in my case is ~116 MB. However, when I use `Algorithm.from_checkpoint()` to reinitialize the model (i.e., load the weights and state parameters from the state file) and then immediately call `Algorithm.save_checkpoint()` without any training, the native model saved to the new checkpoint is only ~66 MB. The `tower_stats` and `_last_output` attributes of the new model were not the same as the original one's. Even after modifying `torch_policy.py` to save those to the state file and load them from it, the model was still smaller (~88 MB) than the original, indicating that some information is still missing. The only solution I found that solved the problem was the following:

ray.init(num_cpus=12, num_gpus=2)

register_env('carla', env_creator)

os.system('nvidia-smi')

run_dir = os.path.join(args.directory, args.name)

# makedirs also creates missing parent directories and is a no-op when the
# directory already exists (os.mkdir failed if the parent was missing).
os.makedirs(run_dir, exist_ok=True)

if not args.restore:
    sac_config = (
        SACConfig()
        .framework(**args.config['framework'])
        .environment(**args.config['environment'])
        .callbacks(**args.config['callbacks'])
        .rollouts(**args.config['rollouts'])
        .fault_tolerance(**args.config['fault_tolerance'])
        .resources(**args.config['resources'])
        .debugging(**args.config['debugging'])
        .checkpointing(**args.config['checkpointing'])
        .reporting(**args.config['reporting'])
        .training(**args.config['training'])
    )

    sac_algo = sac_config.build()
else:
    sac_algo = Algorithm.from_checkpoint(run_dir)

    # Workaround: swap the natively exported model back into the restored
    # policy, then rebuild the policy state that depends on the model object.
    policy = sac_algo.get_policy()

    model = torch.load(os.path.join(run_dir, 'policies', 'default_policy', 'model', 'model.pt'))

    policy.model = copy.deepcopy(model)
    policy.target_model = copy.deepcopy(model)

    gpu_ids = list(range(torch.cuda.device_count()))

    # One device per GPU the policy is configured to use: visible GPUs capped
    # by resources.num_gpus.
    devices = [
        torch.device("cuda:{}".format(i))
        for i, _ in enumerate(gpu_ids)
        if i < args.config['resources']['num_gpus']
    ]
    if not devices:
        # CPU-only fallback so the tower list below is never empty.
        devices = [torch.device("cpu")]

    # Build one tower per configured device. Iterating over every visible GPU
    # (gpu_ids) instead raised IndexError on devices[i] whenever more GPUs
    # were visible than resources.num_gpus allows.
    policy.model_gpu_towers = [copy.deepcopy(model).to(device) for device in devices]
    policy.model_gpu_towers[0] = policy.model

    policy.target_models = {
        tower: copy.deepcopy(policy.target_model).to(devices[i])
        for i, tower in enumerate(policy.model_gpu_towers)
    }

    policy._state_inputs = policy.model.get_initial_state()
    policy._is_recurrent = len(policy._state_inputs) > 0
    policy._update_model_view_requirements_from_init_state()
    policy.view_requirements.update(policy.model.view_requirements)

    # NOTE(review): this is the loaded ``model`` object, not ``policy.model``
    # (a deep copy of it); confirm which one the policy expects here.
    policy.unwrapped_model = model

    # Recreate the optimizers so they reference the new model's parameters.
    policy._optimizers = force_list(policy.optimizer())

    # For each optimizer, the set of indices (into policy.model.parameters())
    # of the parameters it updates.
    main_params = {p: i for i, p in enumerate(policy.model.parameters())}
    policy.multi_gpu_param_groups = [
        {main_params[p] for pg in optimizer.param_groups for p in pg["params"]}
        for optimizer in policy._optimizers
    ]

for i in range(32768):
    print(f'Iteration: {i}')

    sac_algo.train()

    if i % 8 == 0:
        sac_algo.save_checkpoint(run_dir)

I’m essentially plugging the old model back in and re-initializing the model-dependent state (such as the optimizers).