Open 17011813 opened 4 years ago
Hello, the answer is in the stable-baselines documentation (cf. saving/loading at stable-baselines.readthedocs.io/): you need to use a callback to make checkpoints.
We already have a callback that saves the best model: https://github.com/araffin/learning-to-drive-in-5-minutes/blob/master/train.py#L195
Note that the code is out of date and does not match the latest SB version. I plan to open source an updated version in the coming month, but I won't update this repo.
PS: please use markdown to format your code next time
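
For illustration, a periodic checkpoint callback in the same style as the best-model callback linked above might look like the sketch below. It assumes the legacy functional callback convention of stable-baselines, `callback(locals_, globals_)`, where `locals_['self']` is the model being trained and returning `False` aborts training. The helper name `make_checkpoint_callback`, the `save_freq` parameter, and the `checkpoint_<n>` file naming are hypothetical choices made for this example, not part of the repo.

```python
import os


def make_checkpoint_callback(save_dir, save_freq=1000):
    """Return a legacy-style callback that saves the model every `save_freq` calls.

    Assumption: the training loop calls `callback(locals_, globals_)` and
    `locals_['self']` is the model, which exposes a `save(path)` method.
    """
    os.makedirs(save_dir, exist_ok=True)
    state = {"n_calls": 0}  # mutable counter shared with the closure

    def callback(locals_, globals_):
        state["n_calls"] += 1
        if state["n_calls"] % save_freq == 0:
            path = os.path.join(save_dir, "checkpoint_{}".format(state["n_calls"]))
            # Uses the model's own serialization (.pkl or .zip depending on the SB version)
            locals_["self"].save(path)
        return True  # returning False would stop training early

    return callback
```

It would be passed to training as `model.learn(n_timesteps, callback=make_checkpoint_callback("logs/checkpoints", 5000))`. Recent stable-baselines releases also ship a ready-made `CheckpointCallback` in `stable_baselines.common.callbacks`, which is likely the better option when the installed version provides it.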
Thanks for your response! What I mean is that I want to make a checkpoint file such as meta.ckpt or model.ckpt. With your original callback code, I only get .pkl and parameter text files. How can I get a checkpoint file in the .ckpt format?
I also want to get a .pb file after running train.py. That's why I added this code at the end of train.py:
if not os.path.exists(LOGDIR):
    os.makedirs(LOGDIR)
checkpoint_path = os.path.join(LOGDIR, "model.ckpt")
filename = saver.save(sess, checkpoint_path)
print("Model saved in file: %s" % filename)
tf.io.write_graph(sess.graph_def, '.', 'graph.pb', as_text=False)
The reason I want the .pb and .ckpt files is that I want to freeze the model trained by train.py using them :) Or do you know how to make a frozen file from your .pkl and best.zip files? If you do, please let me know <3
Thanks as always for your kindness <3
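
On getting .ckpt and .pb files from a trained agent: one possible approach, sketched below, is to build the `tf.train.Saver` inside the graph that the model itself created, instead of in a separate empty session. This is a minimal sketch under assumptions, not the repo's method: it assumes TensorFlow 1.x and that the stable-baselines model exposes its graph and session as `model.graph` and `model.sess`. The function name `export_checkpoint_and_graph` is hypothetical.

```python
import os


def export_checkpoint_and_graph(model, log_dir="./save"):
    """Write model.ckpt (variables) and graph.pb (graph definition) for a trained model.

    Assumptions: `model.graph` is the tf.Graph the model was built in and
    `model.sess` is the tf.Session that holds its trained variables.
    """
    import tensorflow as tf  # TensorFlow 1.x API, imported lazily on purpose

    os.makedirs(log_dir, exist_ok=True)
    # The Saver must be created in the model's graph so that it finds the
    # model's variables; in an empty default graph it has nothing to save.
    with model.graph.as_default():
        saver = tf.train.Saver()
        checkpoint_path = saver.save(model.sess, os.path.join(log_dir, "model.ckpt"))
    # Serialize the graph structure (without variable values) as a binary .pb file
    tf.io.write_graph(model.graph.as_graph_def(), log_dir, "graph.pb", as_text=False)
    return checkpoint_path
```

It would be called once training has finished, for example `export_checkpoint_and_graph(model, LOGDIR)` after `model.learn(...)`. Freezing the graph afterwards additionally requires the names of the output nodes, which depend on the policy and are not covered by this sketch.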
Hello, I tried to make a checkpoint to save the model. To do so, I imported tensorflow, opened a session, and created `saver = tf.train.Saver()`, but I got the error `ValueError: No variables to save`. I think this is because train.py has no TensorFlow session of its own, so I created one on purpose to save the checkpoint. I also want to make a .pb file. The pretrained agent that you uploaded comes with a .pkl file and parameters, but I want to get a checkpoint file and a .pb file. I ran `python train.py --algo sac -n 20000 -vae logs/vae-level-0-dim-32.pkl` with the code below, but I can't get a checkpoint file. How can I get a checkpoint file with your code? This is the code that I tried:
```python
# Code adapted from https://github.com/araffin/rl-baselines-zoo
# Author: Antonin Raffin

import argparse
import os
import time
import warnings
from collections import OrderedDict
from pprint import pprint

import tensorflow as tf
from tensorflow.core.protobuf import saver_pb2

sess = tf.Session()
LOGDIR = './save'
save_file = './save/model.ckpt'
saver = tf.train.Saver()

if os.path.isfile(save_file + ".meta"):
    saver = tf.train.Saver()
    saver.restore(sess, save_file)
    print("-----------------run with saved model----------------------------------")
else:
    sess.run(tf.global_variables_initializer())
    print("------------------initialize network------------------------------------")

# Remove warnings
warnings.filterwarnings("ignore", category=FutureWarning, module='tensorflow')
warnings.filterwarnings("ignore", category=UserWarning, module='gym')

import numpy as np
import yaml
from stable_baselines.common import set_global_seeds
from stable_baselines.common.vec_env import VecFrameStack, VecNormalize, DummyVecEnv
from stable_baselines.ddpg import AdaptiveParamNoiseSpec, NormalActionNoise, OrnsteinUhlenbeckActionNoise
# from stable_baselines.ppo2.ppo2 import constfn

from config import MIN_THROTTLE, MAX_THROTTLE, FRAME_SKIP,\
    MAX_CTE_ERROR, SIM_PARAMS, N_COMMAND_HISTORY, Z_SIZE, BASE_ENV, ENV_ID, MAX_STEERING_DIFF
from utils.utils import make_env, ALGOS, linear_schedule, get_latest_run_id, load_vae, create_callback
from teleop.teleop_client import TeleopEnv

parser = argparse.ArgumentParser()
parser.add_argument('-tb', '--tensorboard-log', help='Tensorboard log dir', default='', type=str)
parser.add_argument('-i', '--trained-agent', help='Path to a pretrained agent to continue training',
                    default='', type=str)
parser.add_argument('--algo', help='RL Algorithm', default='sac',
                    type=str, required=False, choices=list(ALGOS.keys()))
parser.add_argument('-n', '--n-timesteps', help='Overwrite the number of timesteps', default=-1,
                    type=int)
parser.add_argument('--log-interval', help='Override log interval (default: -1, no change)', default=-1,
                    type=int)
parser.add_argument('-f', '--log-folder', help='Log folder', type=str, default='logs')
parser.add_argument('-vae', '--vae-path', help='Path to saved VAE', type=str, default='')
parser.add_argument('--save-vae', action='store_true', default=False,
                    help='Save VAE')
parser.add_argument('--seed', help='Random generator seed', type=int, default=0)
parser.add_argument('--random-features', action='store_true', default=False,
                    help='Use random features')
parser.add_argument('--teleop', action='store_true', default=False,
                    help='Use teleoperation for training')
args = parser.parse_args()

set_global_seeds(args.seed)

if args.trained_agent != "":
    assert args.trained_agent.endswith('.pkl') and os.path.isfile(args.trained_agent), \
        "The trained_agent must be a valid path to a .pkl file"

tensorboard_log = None if args.tensorboard_log == '' else args.tensorboard_log + '/' + ENV_ID

print("=" * 10, ENV_ID, args.algo, "=" * 10)

vae = None
if args.vae_path != '':
    print("Loading VAE ...")
    vae = load_vae(args.vae_path)
elif args.random_features:
    print("Randomly initialized VAE")
    vae = load_vae(z_size=Z_SIZE)
    # Save network
    args.save_vae = True
else:
    print("Learning from pixels...")

# Load hyperparameters from yaml file
with open('hyperparams/{}.yml'.format(args.algo), 'r') as f:
    hyperparams = yaml.load(f, Loader=yaml.UnsafeLoader)[BASE_ENV]

hyperparams['seed'] = args.seed
# Sort hyperparams that will be saved
saved_hyperparams = OrderedDict([(key, hyperparams[key]) for key in sorted(hyperparams.keys())])
# save vae path
saved_hyperparams['vae_path'] = args.vae_path
if vae is not None:
    saved_hyperparams['z_size'] = vae.z_size

# Save simulation params
for key in SIM_PARAMS:
    saved_hyperparams[key] = eval(key)
pprint(saved_hyperparams)

# Compute and create log path
log_path = os.path.join(args.log_folder, args.algo)
print('-------------------------------------------------------------------------------------')
save_path = os.path.join(log_path, "{}_{}".format(ENV_ID, get_latest_run_id(log_path, ENV_ID) + 1))
print(os.path.abspath(save_path))
params_path = os.path.join(save_path, ENV_ID)
os.makedirs(params_path, exist_ok=True)


def constfn(val):
    def f(_):
        return val
    return f


# Create learning rate schedules for ppo2 and sac
if args.algo in ["ppo2", "sac"]:
    for key in ['learning_rate', 'cliprange']:
        if key not in hyperparams:
            continue
        if isinstance(hyperparams[key], str):
            schedule, initial_value = hyperparams[key].split('_')
            initial_value = float(initial_value)
            hyperparams[key] = linear_schedule(initial_value)
        elif isinstance(hyperparams[key], float):
            hyperparams[key] = constfn(hyperparams[key])
        else:
            raise ValueError('Invalid valid for {}: {}'.format(key, hyperparams[key]))

# Should we overwrite the number of timesteps?
if args.n_timesteps > 0:
    n_timesteps = args.n_timesteps
else:
    n_timesteps = int(hyperparams['n_timesteps'])
del hyperparams['n_timesteps']

normalize = False
normalize_kwargs = {}
if 'normalize' in hyperparams.keys():
    normalize = hyperparams['normalize']
    if isinstance(normalize, str):
        normalize_kwargs = eval(normalize)
        normalize = True
    del hyperparams['normalize']

if not args.teleop:
    env = DummyVecEnv([make_env(args.seed, vae=vae, teleop=args.teleop)])
else:
    env = make_env(args.seed, vae=vae, teleop=args.teleop,
                   n_stack=hyperparams.get('frame_stack', 1))()

print('**')

if normalize:
    if hyperparams.get('normalize', False) and args.algo in ['ddpg']:
        print("WARNING: normalization not supported yet for DDPG")
    else:
        print("Normalizing input and return")
        env = VecNormalize(env, **normalize_kwargs)

# Optional Frame-stacking
n_stack = 1
if hyperparams.get('frame_stack', False):
    n_stack = hyperparams['frame_stack']
    if not args.teleop:
        env = VecFrameStack(env, n_stack)
    print("Stacking {} frames".format(n_stack))
    del hyperparams['frame_stack']

# Parse noise string for DDPG
if args.algo == 'ddpg' and hyperparams.get('noise_type') is not None:
    noise_type = hyperparams['noise_type'].strip()
    noise_std = hyperparams['noise_std']
    n_actions = env.action_space.shape[0]
    if 'adaptive-param' in noise_type:
        hyperparams['param_noise'] = AdaptiveParamNoiseSpec(initial_stddev=noise_std,
                                                            desired_action_stddev=noise_std)
    elif 'normal' in noise_type:
        hyperparams['action_noise'] = NormalActionNoise(mean=np.zeros(n_actions),
                                                        sigma=noise_std * np.ones(n_actions))
    elif 'ornstein-uhlenbeck' in noise_type:
        hyperparams['action_noise'] = OrnsteinUhlenbeckActionNoise(mean=np.zeros(n_actions),
                                                                   sigma=noise_std * np.ones(n_actions))
    else:
        raise RuntimeError('Unknown noise type "{}"'.format(noise_type))
    print("Applying {} noise with std {}".format(noise_type, noise_std))
    del hyperparams['noise_type']
    del hyperparams['noise_std']

if args.trained_agent.endswith('.pkl') and os.path.isfile(args.trained_agent):
    # Continue training
    print("Loading pretrained agent")
    # Policy should not be changed
    del hyperparams['policy']

    model = ALGOS[args.algo].load(args.trained_agent, env=env,
                                  tensorboard_log=tensorboard_log, verbose=1, **hyperparams)

    exp_folder = args.trained_agent.split('.pkl')[0]
    if normalize:
        print("Loading saved running average")
        env.load_running_average(exp_folder)
else:
    # Train an agent from scratch
    model = ALGOS[args.algo](env=env, tensorboard_log=tensorboard_log, verbose=1, **hyperparams)

# Teleoperation mode:
# we don't wrap the environment with a monitor or in a vecenv
if args.teleop:
    assert args.algo == "sac", "Teleoperation mode is not yet implemented for {}".format(args.algo)
    env = TeleopEnv(env, is_training=True)
    model.set_env(env)
    env.model = model

kwargs = {}
if args.log_interval > -1:
    kwargs = {'log_interval': args.log_interval}

if args.algo == 'sac':
    kwargs.update({'callback': create_callback(args.algo,
                                               os.path.join(save_path, ENV_ID + "_best"),
                                               verbose=1)})

model.learn(n_timesteps, **kwargs)

if args.teleop:
    env.wait()
    env.exit()
    time.sleep(0.5)
else:
    # Close the connection properly
    env.reset()
    if isinstance(env, VecFrameStack):
        env = env.venv
    # HACK to bypass Monitor wrapper
    env.envs[0].env.exit_scene()

# Save trained model
model.save(os.path.join(save_path, ENV_ID), cloudpickle=True)
# Save hyperparams
with open(os.path.join(params_path, 'config.yml'), 'w') as f:
    yaml.dump(saved_hyperparams, f)

if args.save_vae and vae is not None:
    print("Saving VAE")
    vae.save(os.path.join(params_path, 'vae'))

if normalize:
    # Unwrap
    if isinstance(env, VecFrameStack):
        env = env.venv
    # Important: save the running average, for testing the agent we need that normalization
    env.save_running_average(params_path)

if not os.path.exists(LOGDIR):
    os.makedirs(LOGDIR)
checkpoint_path = os.path.join(LOGDIR, "model.ckpt")
filename = saver.save(sess, checkpoint_path)
print("Model saved in file: %s" % filename)
tf.io.write_graph(sess.graph_def, '.', 'graph.pb', as_text=False)
```
I really appreciate this project! Thanks a lot 👍