GoodarzMehr opened 8 months ago
In my debugging I noticed that the native model saved to a checkpoint is, in my case, around ~116 MB, but when I use `Algorithm.from_checkpoint()` to reinitialize the model (i.e. load the weights and state parameters from the state file) and then immediately call `Algorithm.save_checkpoint()` without any training, the native model saved to the new checkpoint is ~66 MB. The `tower_stats` and `_last_output` attributes of the new model were not the same as those of the original one, but even after modifying `torch_policy.py` to save those to the state file and load them from it, the model was still smaller (~88 MB) than the original one, indicating some information is still missing. The only solution I found that solves the problem was doing this:
```python
import copy
import os

import ray
import torch

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.sac import SACConfig
from ray.rllib.utils import force_list
from ray.tune.registry import register_env

# `args` (command-line arguments, including the parsed config file) and
# `env_creator` are defined elsewhere in the script.

ray.init(num_cpus=12, num_gpus=2)

register_env('carla', env_creator)

os.system('nvidia-smi')

if not os.path.exists(os.path.join(args.directory, args.name)):
    os.mkdir(os.path.join(args.directory, args.name))

if not args.restore:
    # First run: build the algorithm from the config file.
    sac_config = SACConfig().framework(**args.config['framework']) \
        .environment(**args.config['environment']) \
        .callbacks(**args.config['callbacks']) \
        .rollouts(**args.config['rollouts']) \
        .fault_tolerance(**args.config['fault_tolerance']) \
        .resources(**args.config['resources']) \
        .debugging(**args.config['debugging']) \
        .checkpointing(**args.config['checkpointing']) \
        .reporting(**args.config['reporting']) \
        .training(**args.config['training'])

    sac_algo = sac_config.build()
else:
    # Restore run: load the algorithm from the checkpoint, then plug the
    # previously exported native model back into the policy.
    sac_algo = Algorithm.from_checkpoint(os.path.join(args.directory, args.name))

    model = torch.load(os.path.join(args.directory, args.name, 'policies', 'default_policy', 'model', 'model.pt'))

    sac_algo.get_policy().model = copy.deepcopy(model)
    sac_algo.get_policy().target_model = copy.deepcopy(model)

    # Re-create the per-GPU model towers and target models.
    gpu_ids = list(range(torch.cuda.device_count()))

    devices = [
        torch.device("cuda:{}".format(i))
        for i, id_ in enumerate(gpu_ids)
        if i < args.config['resources']['num_gpus']
    ]

    sac_algo.get_policy().model_gpu_towers = []

    for i, _ in enumerate(gpu_ids):
        model_copy = copy.deepcopy(model)
        sac_algo.get_policy().model_gpu_towers.append(model_copy.to(devices[i]))

    sac_algo.get_policy().model_gpu_towers[0] = sac_algo.get_policy().model

    sac_algo.get_policy().target_models = {
        m: copy.deepcopy(sac_algo.get_policy().target_model).to(devices[i])
        for i, m in enumerate(sac_algo.get_policy().model_gpu_towers)
    }

    # Re-derive recurrent state and view requirements from the restored model.
    sac_algo.get_policy()._state_inputs = sac_algo.get_policy().model.get_initial_state()
    sac_algo.get_policy()._is_recurrent = len(sac_algo.get_policy()._state_inputs) > 0
    sac_algo.get_policy()._update_model_view_requirements_from_init_state()
    sac_algo.get_policy().view_requirements.update(sac_algo.get_policy().model.view_requirements)

    sac_algo.get_policy().unwrapped_model = model

    # Re-create the optimizers and multi-GPU parameter groups so they
    # reference the restored model's parameters.
    sac_algo.get_policy()._optimizers = force_list(sac_algo.get_policy().optimizer())

    sac_algo.get_policy().multi_gpu_param_groups = []

    main_params = {p: i for i, p in enumerate(sac_algo.get_policy().model.parameters())}

    for o in sac_algo.get_policy()._optimizers:
        param_indices = []
        for pg_idx, pg in enumerate(o.param_groups):
            for p in pg["params"]:
                param_indices.append(main_params[p])
        sac_algo.get_policy().multi_gpu_param_groups.append(set(param_indices))

# Train, saving a checkpoint every 8 iterations.
for i in range(32768):
    print(f'Iteration: {i}')

    sac_algo.train()

    if i % 8 == 0:
        sac_algo.save_checkpoint(os.path.join(args.directory, args.name))
```
I’m essentially plugging the old model back in and re-initializing the model-based stuff (like the optimizers).
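For reference, the size discrepancy described above can be reproduced with a roundtrip roughly like the following sketch. It assumes native model files are exported into the checkpoint so that `policies/default_policy/model/model.pt` exists (as in the script above); `args` is the same namespace as above, and the `'_restored'` directory suffix is just a placeholder.

```python
import os

from ray.rllib.algorithms.algorithm import Algorithm

original_dir = os.path.join(args.directory, args.name)                # existing checkpoint
restored_dir = os.path.join(args.directory, args.name + '_restored')  # new checkpoint dir

# Load the algorithm from the existing checkpoint and save it again immediately,
# without any training in between.
algo = Algorithm.from_checkpoint(original_dir)
os.makedirs(restored_dir, exist_ok=True)
algo.save_checkpoint(restored_dir)

# Compare the exported native model sizes (~116 MB originally vs. ~66 MB after).
rel_path = os.path.join('policies', 'default_policy', 'model', 'model.pt')
for d in (original_dir, restored_dir):
    print(d, os.path.getsize(os.path.join(d, rel_path)) / 1e6, 'MB')
```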
What happened + What you expected to happen
I am using RLlib's SAC with a multi-agent environment that crashes from time to time due to memory issues. Using `Tuner.restore()` (or, alternatively, `Algorithm.save_checkpoint()` and `Algorithm.from_checkpoint()`) does not restore the training; instead it re-initializes it, as can be seen in the image below. See here for further information.
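The restore pattern I am referring to is roughly the following sketch (the experiment path is a placeholder, and the registered `"SAC"` trainable is passed explicitly; my actual script uses the `Algorithm` API shown above):

```python
import os

from ray import tune

# Restore the crashed Tune experiment and continue training; in my case the
# restored run starts over instead of resuming from where it left off.
tuner = tune.Tuner.restore(
    path=os.path.join(args.directory, args.name),  # experiment directory (placeholder)
    trainable="SAC",
    resume_errored=True,
)
results = tuner.fit()
```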
Versions / Dependencies
- Ubuntu 22.04
- Python 3.8.10
- Ray 2.9.2
- Torch 1.10.1 (cu113)
- CUDA 11.3
Reproduction script
I am using the following script:
along with the following config file:
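Its top-level structure mirrors the builder calls in the script above, roughly as in this sketch (the values below are placeholders, not my actual settings):

```python
# Rough top-level structure of args.config; every value here is a placeholder.
config = {
    'framework': {'framework': 'torch'},
    'environment': {'env': 'carla'},  # registered via register_env('carla', env_creator)
    'callbacks': {},
    'rollouts': {},
    'fault_tolerance': {},
    'resources': {'num_gpus': 2},
    'debugging': {},
    'checkpointing': {'export_native_model_files': True},
    'reporting': {},
    'training': {},
}
```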
`CarlaEnv` is a rather complex Gym environment, but basically it has the following action and observation spaces:

Issue Severity
Medium: It is a significant difficulty but I can work around it.