leggedrobotics / raisimGym

Option to load a pretrained model and continue training? #21

Closed haloted closed 5 years ago

haloted commented 5 years ago

Is there an option to load a pretrained model and continue training?

jhwangbo commented 5 years ago

I haven't tried this, but stable-baselines uses TensorFlow, and I think you can just load the whole graph from a checkpoint.

haloted commented 5 years ago

I modified the code in quadrotor_position_tracking.py to the following:

if mode == 'pretrained':
    # load model first
    weight_path = args.weight
    if weight_path == "":
        print("Can't find trained weight, please provide a trained weight with --weight switch\n")
    else:
        print("Loaded weight from {}\n".format(weight_path))
        model = PPO2.load(weight_path,env=env)

    # tensorboard
    # Make sure that your chrome browser is already on.
    TensorboardLauncher(saver.data_dir + '/PPO2_1')

    # PPO run
    # Originally the total timestep is 500000000
    # 10 zeros for nupdates to be 4000
    # 1000000000 is 2000 iterations and so
    # 2000000000 is 4000 iterations.
    model.learn(
        total_timesteps=100000000000,
        eval_every_n=50,
        log_dir=saver.data_dir,
        record_video=cfg['record_video']
    )
    model.save(saver.data_dir)
    # Need this line if you want to keep tensorflow alive after training
    input("Press Enter to exit... Tensorboard will be closed after exit\n")

And I got the following error:

Traceback (most recent call last):
  File "quadrotor_position_tracking_pretrained.py", line 122, in <module>
    model = PPO2.load(weight_path,env=env)
  File "/usr/local/lib/python3.5/dist-packages/stable_baselines-2.7.0-py3.5.egg/stable_baselines/common/base_class.py", line 700, in load
    model.set_env(env)
  File "/usr/local/lib/python3.5/dist-packages/stable_baselines-2.7.0-py3.5.egg/stable_baselines/common/base_class.py", line 103, in set_env
    assert not self.policy.recurrent or self.n_envs == env.num_envs, \
AttributeError: type object 'MlpPolicy' has no attribute 'recurrent'
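A side note on the timestep comments in the script: they imply a fixed number of timesteps per update. Here is a minimal sketch of that arithmetic. It assumes the number of updates is total_timesteps // batch size, as in stable-baselines' PPO2. The batch size of 500,000 is inferred from the comment "1000000000 is 2000 iterations", not read from the actual configuration, and n_updates is a hypothetical helper name.

```python
# Assumption: batch size inferred from the script comments
# ("1000000000 is 2000 iterations"), not from the real raisimGym config.
TIMESTEPS_PER_UPDATE = 1000000000 // 2000  # 500000


def n_updates(total_timesteps, batch_size=TIMESTEPS_PER_UPDATE):
    """Number of PPO updates, assuming updates = total_timesteps // batch_size."""
    return total_timesteps // batch_size


# The two values from the comments, then the value passed to model.learn.
for total in (1000000000, 2000000000, 100000000000):
    print(total, "->", n_updates(total), "updates")
```

Under these assumptions, the value actually passed to model.learn (100000000000) would correspond to 200000 updates, not 4000.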

jhwangbo commented 5 years ago

I haven't tried retraining a network. You should look at the stable-baselines manual to do so. But I think that if you load a network, it is set to test mode and you cannot train it anymore.

alexpostnikov commented 5 years ago

If I'm not mistaken, this might actually be a bug. MlpPolicy inherits from BasePolicy, which is located at raisim_gym/archi/policies.py

In the current version of stable-baselines there is a class variable: recurrent = False https://github.com/hill-a/stable-baselines/blob/3105f30f53ae20e4a9b9bad166ebec20cdefa2dc/stable_baselines/common/policies.py#L108

But the local version in raisimGym is missing it. https://github.com/leggedrobotics/raisimGym/blob/f9356d379ebfc13616553bcc60c640e06513c6fe/raisim_gym/archi/policies.py#L107

In my case, the missing variable made the assert in set_env fail.
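The failure can be reproduced without stable-baselines installed. This is a minimal sketch using stand-in classes. OldBasePolicy, NewBasePolicy, OldMlpPolicy, NewMlpPolicy and check_recurrent are hypothetical names, not the real raisimGym or stable-baselines code. They only mimic a policy class with and without the recurrent class variable, plus the attribute access made by the assert in the traceback.

```python
class OldBasePolicy:
    """Stand-in for a base class that lacks the `recurrent` class variable."""


class NewBasePolicy:
    """Stand-in for a base class that defines `recurrent = False`."""
    recurrent = False


class OldMlpPolicy(OldBasePolicy):
    pass


class NewMlpPolicy(NewBasePolicy):
    pass


def check_recurrent(policy_cls, n_envs, env_num_envs):
    # Same shape as the assert in the traceback: `recurrent` is read from
    # the policy class itself, so a missing attribute raises AttributeError
    # before the env-count comparison is ever evaluated.
    assert not policy_cls.recurrent or n_envs == env_num_envs, "env count mismatch"
    return True


print(check_recurrent(NewMlpPolicy, 1, 100))
try:
    check_recurrent(OldMlpPolicy, 1, 100)
except AttributeError as err:
    print("AttributeError:", err)
```

The error is an AttributeError, not an AssertionError, because the attribute lookup fails while the assert condition is still being evaluated.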

jhwangbo commented 5 years ago

Maybe this is due to the version change of stable-baselines. I added recurrent = False to the file. Let me know if you still have that issue.
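For anyone on an older checkout without that change, one possible workaround is to set the missing class variable before loading the model. This is a minimal sketch with a hypothetical helper: ensure_recurrent_flag is not part of raisimGym or stable-baselines, and the MlpPolicy below is only a stand-in. It assumes the policy is feed-forward, so False is the appropriate value.

```python
def ensure_recurrent_flag(policy_cls, default=False):
    """Add a `recurrent` class variable if the policy class lacks one.

    Assumption: the policy is not recurrent, so `False` is a safe default.
    """
    if not hasattr(policy_cls, "recurrent"):
        policy_cls.recurrent = default
    return policy_cls


class MlpPolicy:
    """Stand-in for a policy class that lacks `recurrent`."""


ensure_recurrent_flag(MlpPolicy)
print(MlpPolicy.recurrent)
```

In the actual script, the helper would presumably be called on the imported MlpPolicy before PPO2.load(weight_path, env=env). Updating to a version of the file that already contains recurrent = False is the cleaner fix.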