araffin / learning-to-drive-in-5-minutes

Implementation of a reinforcement learning approach to make a car learn to drive smoothly in minutes
https://towardsdatascience.com/learning-to-drive-smoothly-in-minutes-450a7cdb35f4
MIT License

[Question] How can I get checkpoint while training? #27

Open 17011813 opened 4 years ago

17011813 commented 4 years ago

Hello, I tried to make a checkpoint to save the model, so I opened a session with `import tensorflow` and `saver = tf.train.Saver()`. But I got the error `ValueError: No variables to save`. I think this is because there is no TensorFlow session in train.py, but I created a session on purpose to save the checkpoint. I also want to make a .pb file. The pretrained agent that you uploaded has a .pkl file and parameters, but I want a checkpoint file and a .pb file. This is the code that I tried. I ran `python train.py --algo sac -n 20000 -vae logs/vae-level-0-dim-32.pkl` with this code, but I can't get a checkpoint file. How can I get a checkpoint file with your code?

```python
# Code adapted from https://github.com/araffin/rl-baselines-zoo
# Author: Antonin Raffin
import argparse
import os
import time
import warnings
from collections import OrderedDict
from pprint import pprint

import tensorflow as tf
from tensorflow.core.protobuf import saver_pb2

sess = tf.Session()
LOGDIR = './save'
save_file = './save/model.ckpt'
saver = tf.train.Saver()

if os.path.isfile(save_file + ".meta"):
    saver = tf.train.Saver()
    saver.restore(sess, save_file)
    print("-----------------run with saved model----------------------------------")
else:
    sess.run(tf.global_variables_initializer())
    print("------------------initialize network------------------------------------")

# Remove warnings
warnings.filterwarnings("ignore", category=FutureWarning, module='tensorflow')
warnings.filterwarnings("ignore", category=UserWarning, module='gym')

import numpy as np
import yaml
from stable_baselines.common import set_global_seeds
from stable_baselines.common.vec_env import VecFrameStack, VecNormalize, DummyVecEnv
from stable_baselines.ddpg import AdaptiveParamNoiseSpec, NormalActionNoise, OrnsteinUhlenbeckActionNoise
# from stable_baselines.ppo2.ppo2 import constfn

from config import MIN_THROTTLE, MAX_THROTTLE, FRAME_SKIP,\
    MAX_CTE_ERROR, SIM_PARAMS, N_COMMAND_HISTORY, Z_SIZE, BASE_ENV, ENV_ID, MAX_STEERING_DIFF
from utils.utils import make_env, ALGOS, linear_schedule, get_latest_run_id, load_vae, create_callback
from teleop.teleop_client import TeleopEnv

parser = argparse.ArgumentParser()
parser.add_argument('-tb', '--tensorboard-log', help='Tensorboard log dir', default='', type=str)
parser.add_argument('-i', '--trained-agent', help='Path to a pretrained agent to continue training',
                    default='', type=str)
parser.add_argument('--algo', help='RL Algorithm', default='sac',
                    type=str, required=False, choices=list(ALGOS.keys()))
parser.add_argument('-n', '--n-timesteps', help='Overwrite the number of timesteps', default=-1,
                    type=int)
parser.add_argument('--log-interval', help='Override log interval (default: -1, no change)', default=-1,
                    type=int)
parser.add_argument('-f', '--log-folder', help='Log folder', type=str, default='logs')
parser.add_argument('-vae', '--vae-path', help='Path to saved VAE', type=str, default='')
parser.add_argument('--save-vae', action='store_true', default=False,
                    help='Save VAE')
parser.add_argument('--seed', help='Random generator seed', type=int, default=0)
parser.add_argument('--random-features', action='store_true', default=False,
                    help='Use random features')
parser.add_argument('--teleop', action='store_true', default=False,
                    help='Use teleoperation for training')
args = parser.parse_args()

set_global_seeds(args.seed)

if args.trained_agent != "":
    assert args.trained_agent.endswith('.pkl') and os.path.isfile(args.trained_agent), \
        "The trained_agent must be a valid path to a .pkl file"

tensorboard_log = None if args.tensorboard_log == '' else args.tensorboard_log + '/' + ENV_ID

print("=" * 10, ENV_ID, args.algo, "=" * 10)

vae = None
if args.vae_path != '':
    print("Loading VAE ...")
    vae = load_vae(args.vae_path)
elif args.random_features:
    print("Randomly initialized VAE")
    vae = load_vae(z_size=Z_SIZE)
    # Save network
    args.save_vae = True
else:
    print("Learning from pixels...")

# Load hyperparameters from yaml file
with open('hyperparams/{}.yml'.format(args.algo), 'r') as f:
    hyperparams = yaml.load(f, Loader=yaml.UnsafeLoader)[BASE_ENV]

hyperparams['seed'] = args.seed
# Sort hyperparams that will be saved
saved_hyperparams = OrderedDict([(key, hyperparams[key]) for key in sorted(hyperparams.keys())])
# save vae path
saved_hyperparams['vae_path'] = args.vae_path
if vae is not None:
    saved_hyperparams['z_size'] = vae.z_size

# Save simulation params
for key in SIM_PARAMS:
    saved_hyperparams[key] = eval(key)
pprint(saved_hyperparams)

# Compute and create log path
log_path = os.path.join(args.log_folder, args.algo)
print('-------------------------------------------------------------------------------------')
save_path = os.path.join(log_path, "{}_{}".format(ENV_ID, get_latest_run_id(log_path, ENV_ID) + 1))
print(os.path.abspath(save_path))

params_path = os.path.join(save_path, ENV_ID)
os.makedirs(params_path, exist_ok=True)


def constfn(val):
    def f(_):
        return val
    return f


# Create learning rate schedules for ppo2 and sac
if args.algo in ["ppo2", "sac"]:
    for key in ['learning_rate', 'cliprange']:
        if key not in hyperparams:
            continue
        if isinstance(hyperparams[key], str):
            schedule, initial_value = hyperparams[key].split('_')
            initial_value = float(initial_value)
            hyperparams[key] = linear_schedule(initial_value)
        elif isinstance(hyperparams[key], float):
            hyperparams[key] = constfn(hyperparams[key])
        else:
            raise ValueError('Invalid valid for {}: {}'.format(key, hyperparams[key]))

# Should we overwrite the number of timesteps?
if args.n_timesteps > 0:
    n_timesteps = args.n_timesteps
else:
    n_timesteps = int(hyperparams['n_timesteps'])
del hyperparams['n_timesteps']

normalize = False
normalize_kwargs = {}
if 'normalize' in hyperparams.keys():
    normalize = hyperparams['normalize']
    if isinstance(normalize, str):
        normalize_kwargs = eval(normalize)
        normalize = True
    del hyperparams['normalize']

if not args.teleop:
    env = DummyVecEnv([make_env(args.seed, vae=vae, teleop=args.teleop)])
else:
    env = make_env(args.seed, vae=vae, teleop=args.teleop,
                   n_stack=hyperparams.get('frame_stack', 1))()
print('**')
if normalize:
    if hyperparams.get('normalize', False) and args.algo in ['ddpg']:
        print("WARNING: normalization not supported yet for DDPG")
    else:
        print("Normalizing input and return")
        env = VecNormalize(env, **normalize_kwargs)

# Optional Frame-stacking
n_stack = 1
if hyperparams.get('frame_stack', False):
    n_stack = hyperparams['frame_stack']
    if not args.teleop:
        env = VecFrameStack(env, n_stack)
    print("Stacking {} frames".format(n_stack))
    del hyperparams['frame_stack']

# Parse noise string for DDPG
if args.algo == 'ddpg' and hyperparams.get('noise_type') is not None:
    noise_type = hyperparams['noise_type'].strip()
    noise_std = hyperparams['noise_std']
    n_actions = env.action_space.shape[0]
    if 'adaptive-param' in noise_type:
        hyperparams['param_noise'] = AdaptiveParamNoiseSpec(initial_stddev=noise_std,
                                                            desired_action_stddev=noise_std)
    elif 'normal' in noise_type:
        hyperparams['action_noise'] = NormalActionNoise(mean=np.zeros(n_actions),
                                                        sigma=noise_std * np.ones(n_actions))
    elif 'ornstein-uhlenbeck' in noise_type:
        hyperparams['action_noise'] = OrnsteinUhlenbeckActionNoise(mean=np.zeros(n_actions),
                                                                   sigma=noise_std * np.ones(n_actions))
    else:
        raise RuntimeError('Unknown noise type "{}"'.format(noise_type))
    print("Applying {} noise with std {}".format(noise_type, noise_std))
    del hyperparams['noise_type']
    del hyperparams['noise_std']

if args.trained_agent.endswith('.pkl') and os.path.isfile(args.trained_agent):
    # Continue training
    print("Loading pretrained agent")
    # Policy should not be changed
    del hyperparams['policy']

    model = ALGOS[args.algo].load(args.trained_agent, env=env,
                                  tensorboard_log=tensorboard_log, verbose=1, **hyperparams)

    exp_folder = args.trained_agent.split('.pkl')[0]
    if normalize:
        print("Loading saved running average")
        env.load_running_average(exp_folder)
else:
    # Train an agent from scratch
    model = ALGOS[args.algo](env=env, tensorboard_log=tensorboard_log, verbose=1, **hyperparams)

# Teleoperation mode:
# we don't wrap the environment with a monitor or in a vecenv
if args.teleop:
    assert args.algo == "sac", "Teleoperation mode is not yet implemented for {}".format(args.algo)
    env = TeleopEnv(env, is_training=True)
    model.set_env(env)
    env.model = model

kwargs = {}
if args.log_interval > -1:
    kwargs = {'log_interval': args.log_interval}

if args.algo == 'sac':
    kwargs.update({'callback': create_callback(args.algo,
                                               os.path.join(save_path, ENV_ID + "_best"),
                                               verbose=1)})

model.learn(n_timesteps, **kwargs)

if args.teleop:
    env.wait()
    env.exit()
    time.sleep(0.5)
else:
    # Close the connection properly
    env.reset()
    if isinstance(env, VecFrameStack):
        env = env.venv
    # HACK to bypass Monitor wrapper
    env.envs[0].env.exit_scene()

# Save trained model
model.save(os.path.join(save_path, ENV_ID), cloudpickle=True)

# Save hyperparams
with open(os.path.join(params_path, 'config.yml'), 'w') as f:
    yaml.dump(saved_hyperparams, f)

if args.save_vae and vae is not None:
    print("Saving VAE")
    vae.save(os.path.join(params_path, 'vae'))

if normalize:
    # Unwrap
    if isinstance(env, VecFrameStack):
        env = env.venv
    # Important: save the running average, for testing the agent we need that normalization
    env.save_running_average(params_path)


if not os.path.exists(LOGDIR):
    os.makedirs(LOGDIR)
checkpoint_path = os.path.join(LOGDIR, "model.ckpt")
filename = saver.save(sess, checkpoint_path)
print("Model saved in file: %s" % filename)

tf.io.write_graph(sess.graph_def, '.', 'graph.pb', as_text=False)
```

I really appreciate this project!!! Thanks a lot 👍
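The `ValueError` most likely comes from when the `Saver` is created. `tf.train.Saver()` collects the variables that exist in the *default* graph at the moment it is constructed. At the top of the script no model has been built yet, so that list is empty and TF 1.x raises `No variables to save`. Moving the call to the end of the script would probably not be enough. As far as I can tell, stable-baselines 2.x models build their variables in their own `tf.Graph` and session, exposed as `model.graph` and `model.sess`, and not in the default graph that the top-level `tf.Session()` uses.

The sketch below assumes TF 1.13+ or TF 2.x through `tf.compat.v1`. The helper names `save_checkpoint_and_graph` and `saver_fails_on_empty_graph` are hypothetical. The first helper creates the `Saver` inside the graph that owns the variables and writes `model.ckpt.*` plus a binary `graph.pb`. The second reproduces the original error on an empty graph.

```python
import os

import tensorflow as tf

# TF1-style API; tf.compat.v1 exists in TF 1.13+ and in TF 2.x.
tf1 = tf.compat.v1


def save_checkpoint_and_graph(graph, sess, out_dir):
    """Hypothetical helper: save a .ckpt checkpoint and a binary graph.pb.

    `graph` and `sess` must be the graph/session that own the model variables
    (for a stable-baselines 2.x model, presumably `model.graph` and `model.sess`).
    Returns the checkpoint prefix, e.g. "<out_dir>/model.ckpt".
    """
    os.makedirs(out_dir, exist_ok=True)
    # Build and run the Saver inside the graph that actually holds the variables;
    # a Saver built in a graph without variables raises "No variables to save".
    with graph.as_default():
        saver = tf1.train.Saver()
        ckpt_prefix = saver.save(sess, os.path.join(out_dir, "model.ckpt"))
    # Serialized GraphDef: graph structure only, the weights live in the checkpoint.
    tf.io.write_graph(graph.as_graph_def(), out_dir, "graph.pb", as_text=False)
    return ckpt_prefix


def saver_fails_on_empty_graph():
    """Reproduce the original error: a Saver created in a graph with no variables."""
    with tf.Graph().as_default():
        try:
            tf1.train.Saver()
        except ValueError:
            return True
    return False
```

In train.py, this would replace the top-level `sess`/`saver` setup and the final `saver.save(...)` / `tf.io.write_graph(...)` lines. The replacement is a single call after `model.learn(...)`, such as `save_checkpoint_and_graph(model.graph, model.sess, LOGDIR)`. That call assumes `model.graph` and `model.sess` are populated in the stable-baselines version this repo pins.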

araffin commented 4 years ago

Hello, the answer is in the stable-baselines documentation (cf. saving/loading at stable-baselines.readthedocs.io/): you need to use a callback to make checkpoints.

We already have a callback that saves the best model: https://github.com/araffin/learning-to-drive-in-5-minutes/blob/master/train.py#L195

Note that the code is out of date and does not match the latest SB version. I plan to open-source an updated version in the coming month, but I won't update this repo.

PS: please use markdown to format your code next time
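A periodic checkpoint callback, in the legacy function style that this repo's `create_callback` uses, might look like the sketch below. It assumes stable-baselines 2.x before 2.10, which calls a plain function as `callback(_locals, _globals)`, puts the model in `_locals['self']`, and stops training if the function returns `False`. The names `make_checkpoint_callback`, `save_freq` and `name_prefix` are hypothetical.

In those versions SAC calls the callback once per environment step, while PPO2 calls it once per update, so `save_freq` counts calls and not necessarily timesteps. Also note that `model.save()` writes stable-baselines' own .pkl format, not a TensorFlow `.ckpt`. Newer releases (2.10+) ship `stable_baselines.common.callbacks.CheckpointCallback(save_freq, save_path)` for the same purpose.

```python
import os


def make_checkpoint_callback(save_dir, save_freq, name_prefix="rl_model"):
    """Hypothetical helper: build a legacy-style stable-baselines callback that
    calls model.save() every `save_freq` callback invocations."""
    os.makedirs(save_dir, exist_ok=True)
    state = {"n_calls": 0}  # mutable closure state shared across calls

    def callback(_locals, _globals):
        # In SB 2.x (pre-2.10) the training loop passes its locals();
        # the model itself is available as _locals["self"].
        state["n_calls"] += 1
        if state["n_calls"] % save_freq == 0:
            path = os.path.join(save_dir, "{}_{}_calls".format(name_prefix, state["n_calls"]))
            _locals["self"].save(path)
        return True  # returning False would stop training

    return callback
```

In train.py it could be passed in the same way as the existing callback, for example `kwargs.update({'callback': make_checkpoint_callback(os.path.join(save_path, 'checkpoints'), save_freq=5000)})`. Note that this replaces the best-model callback rather than adding to it.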

17011813 commented 4 years ago

Thanks for your response!! But I mean that I want to make a checkpoint file, something like a meta.ckpt or model.ckpt file. With your original callback code, I could only get .pkl and parameter text files. How can I get a checkpoint file in .ckpt format?

I also want to get a .pb file after running train.py. That's why I added this code at the end of train.py:

```python
if not os.path.exists(LOGDIR):
    os.makedirs(LOGDIR)
checkpoint_path = os.path.join(LOGDIR, "model.ckpt")
filename = saver.save(sess, checkpoint_path)
print("Model saved in file: %s" % filename)

tf.io.write_graph(sess.graph_def, '.', 'graph.pb', as_text=False)
```

The reason I want the .pb and .ckpt files is that I want to freeze the model trained with train.py using them :) Or do you know how to make a frozen graph from your .pkl and best.zip files? If you do, please let me know <3

Thanks as always for your kindness <3
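On the freezing question: a frozen .pb is usually produced from a live session, not from .ckpt + .pb files on disk, by folding every variable into a constant. Below is a minimal sketch, assuming TF 1.x or TF 2.x through `tf.compat.v1`, where `convert_variables_to_constants` still exists but is deprecated.

For a stable-baselines agent, the idea would be to load it (e.g. `SAC.load(path)`) and pass `model.sess` together with the name of the node that outputs the action. I don't know that node name, so treat `output_node_names` as an assumption. You would need to look it up in the model's graph, for instance by listing `n.name for n in model.graph.as_graph_def().node`. The helpers `freeze_to_pb` and `load_frozen_graph` are hypothetical names.

```python
import tensorflow as tf

# TF1-style API; tf.compat.v1 exists in TF 1.13+ and in TF 2.x.
tf1 = tf.compat.v1


def freeze_to_pb(sess, output_node_names, pb_path):
    """Hypothetical helper: fold the variables reachable from `output_node_names`
    into constants and write the self-contained GraphDef to `pb_path`."""
    frozen = tf1.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_node_names)
    with open(pb_path, "wb") as f:
        f.write(frozen.SerializeToString())
    return frozen


def load_frozen_graph(pb_path):
    """Hypothetical helper: read a frozen .pb into a fresh tf.Graph.

    name="" keeps the original node names (no "import/" prefix).
    """
    graph_def = tf1.GraphDef()
    with open(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        tf1.import_graph_def(graph_def, name="")
    return graph
```

The frozen file no longer needs the checkpoint, because the weights are stored inside the GraphDef as constants. Inputs, such as observation placeholders, still have to be fed by name when running it.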