haloted closed 5 years ago
I haven't tried this but stable-baselines is using TensorFlow and I think you can just load the whole graph from a checkpoint.
I modified the code in quadrotor_position_tracking.py to the following:

if mode == 'pretrained':
    # load model first
    weight_path = args.weight
    if weight_path == "":
        print("Can't find trained weight, please provide a trained weight with --weight switch\n")
    else:
        print("Loaded weight from {}\n".format(weight_path))
        model = PPO2.load(weight_path, env=env)
        # tensorboard
        # Make sure that your chrome browser is already on.
        TensorboardLauncher(saver.data_dir + '/PPO2_1')
        # PPO run
        # Originally the total timestep is 500000000
        # 10 zeros for nupdates to be 4000
        # 1000000000 is 2000 iterations and so
        # 2000000000 is 4000 iterations.
        model.learn(
            total_timesteps=100000000000,
            eval_every_n=50,
            log_dir=saver.data_dir,
            record_video=cfg['record_video']
        )
        model.save(saver.data_dir)
        # Need this line if you want to keep tensorflow alive after training
        input("Press Enter to exit... Tensorboard will be closed after exit\n")

And I got the following error:
Traceback (most recent call last):
File "quadrotor_position_tracking_pretrained.py", line 122, in
I haven't tried retraining a network. You should look at the stable-baselines manual to do so. But I think that when you load a network, it is set to test mode and you cannot train it anymore.
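For what it's worth, upstream stable-baselines describes the `env` argument of `load` as the environment to run the loaded model on, so continuing training after a load may be possible there. Below is a minimal sketch of that load, learn, save sequence. The helper name `continue_training` and its parameters are hypothetical. It only uses the three calls that already appear in the snippet above (`load`, `learn`, `save`), and whether raisimGym's local PPO2 accepts them in this order is an assumption, not something verified here.

```python
def continue_training(model_cls, weight_path, env, extra_timesteps,
                      save_path, **learn_kwargs):
    """Load saved weights with an env attached, train further, then save.

    model_cls is expected to behave like PPO2 (hypothetical helper,
    not part of stable-baselines or raisimGym).
    """
    if not weight_path:
        raise ValueError("please provide a trained weight with --weight")

    # Passing env at load time gives the restored model an environment
    # to collect rollouts from, which learn() needs.
    model = model_cls.load(weight_path, env=env)

    # Extra keyword arguments (e.g. eval_every_n, log_dir in the
    # raisimGym variant) are forwarded unchanged.
    model.learn(total_timesteps=extra_timesteps, **learn_kwargs)

    model.save(save_path)
    return model
```

With the names from the snippet above, a call might look like `continue_training(PPO2, args.weight, env, 1000000, saver.data_dir + "/retrained", eval_every_n=50, log_dir=saver.data_dir)`; the step count and save path are placeholders.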
If I'm not mistaken, it actually might be a bug. MlpPolicy inherits from BasePolicy, which is located at raisim_gym/archi/policies.py.
In the current version of stable-baselines there is a class variable:
recurrent = False
https://github.com/hill-a/stable-baselines/blob/3105f30f53ae20e4a9b9bad166ebec20cdefa2dc/stable_baselines/common/policies.py#L108
But the local raisimGym version is missing it: https://github.com/leggedrobotics/raisimGym/blob/f9356d379ebfc13616553bcc60c640e06513c6fe/raisim_gym/archi/policies.py#L107
In my case this led to an assertion failure.
Maybe this is due to the version change of stable-baselines. I added recurrent = False to the file. Let me know if you still have that issue.
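To check whether a vendored policy file has drifted from upstream in other ways, a small comparison like the following may help. It is only a sketch: the three classes are stand-ins written for this example, not the real stable-baselines or raisimGym classes, and `missing_class_attributes` is a hypothetical helper.

```python
class UpstreamBasePolicy:
    """Stand-in for the upstream base class, which declares the flag."""
    recurrent = False


class LocalBasePolicy:
    """Stand-in for a vendored copy that predates the flag."""


class PatchedLocalBasePolicy(LocalBasePolicy):
    """The vendored copy after adding the missing class variable."""
    recurrent = False


def missing_class_attributes(reference, candidate):
    """Return public names defined on `reference` that `candidate` lacks."""
    return sorted(
        name
        for name in vars(reference)
        if not name.startswith("_") and not hasattr(candidate, name)
    )


print(missing_class_attributes(UpstreamBasePolicy, LocalBasePolicy))
print(missing_class_attributes(UpstreamBasePolicy, PatchedLocalBasePolicy))
```

The first call reports `recurrent` as missing and the second reports nothing, which mirrors the fix described above.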
Is there an option to load a pretrained model and continue training?