hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License

[question] How do I load a tensorflow ckpt? #1147

Open Syzygianinfern0 opened 2 years ago

Syzygianinfern0 commented 2 years ago

I am trying to load a pre-trained model from some old code that uses this framework, and my familiarity with TensorFlow is very limited. I've tried several approaches but cannot find the right way to load it 🤯

Below is a representative example of how the model is created and saved. I just want to load the saved weights back for evaluation.

import tensorflow as tf
from stable_baselines import PPO1
from stable_baselines.common.policies import FeedForwardPolicy

training_sess = None

class MyMlpPolicy(FeedForwardPolicy):
    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, **_kwargs):
        super(MyMlpPolicy, self).__init__(
            sess,
            ob_space,
            ac_space,
            n_env,
            n_steps,
            n_batch,
            reuse,
            net_arch=[{"pi": [32, 16], "vf": [32, 16]}],
            feature_extraction="mlp",
            **_kwargs
        )
        global training_sess
        training_sess = sess

# `env` is the Gym environment, created elsewhere in the original code
model = PPO1(MyMlpPolicy, env)

# This is how the model is saved
with model.graph.as_default():
    saver = tf.train.Saver()
    saver.save(training_sess, "./model_0.ckpt")

# The above step produces four files:
# 1. checkpoint
# 2. model_0.ckpt.data-00000-of-00001
# 3. model_0.ckpt.index
# 4. model_0.ckpt.meta

Miffyli commented 2 years ago

Please fill in the issue template. If you only want to save the full agent, you do not need to do any TF stuff; just use the `save` and `load` functions (see the examples in the docs). We cannot offer custom tech support for saving/loading in a custom way like this.
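
For reference, the built-in round trip could look roughly like the sketch below. It assumes Stable-Baselines 2 with TensorFlow 1.x is installed; the helper names `save_agent` and `load_agent` and the file name `ppo1_agent` are made up for illustration.

```python
def save_agent(model, path="ppo1_agent"):
    """Save the full agent (hyperparameters and weights) through the library API."""
    # `model` is assumed to be a Stable-Baselines model such as PPO1.
    model.save(path)


def load_agent(path="ppo1_agent", env=None):
    """Recreate the agent from the file written by `save_agent`."""
    # Imported lazily so this module can be imported without TensorFlow 1.x.
    from stable_baselines import PPO1

    # `env` may be left as None when the agent is only used for prediction.
    return PPO1.load(path, env=env)
```

After `model = load_agent()`, calling `model.predict(obs)` should be enough for evaluation. This only helps when the agent was saved with `model.save` in the first place, not when only a raw `.ckpt` is available.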

araffin commented 2 years ago

We also highly recommend switching to Stable-Baselines3 (PyTorch).

Syzygianinfern0 commented 2 years ago

> We also highly recommend switching to Stable-Baselines3 (PyTorch).

Yeah, that is what I currently use. I just need to run some old code for a comparison, and all I have are the weights they provided.
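
One untested way to get such a checkpoint back into the old code is to rebuild the model exactly as it was built for training and restore into its session with `tf.train.Saver`, mirroring the save step. The sketch below assumes TensorFlow 1.x and that the rebuilt graph uses the same variable names as the checkpoint; `list_checkpoint_variables` and `restore_weights` are hypothetical helper names.

```python
def list_checkpoint_variables(ckpt_prefix="./model_0.ckpt"):
    """Return the (name, shape) pairs stored in the checkpoint."""
    import tensorflow as tf  # imported lazily; TensorFlow 1.x is assumed

    return tf.train.list_variables(ckpt_prefix)


def restore_weights(model, ckpt_prefix="./model_0.ckpt"):
    """Load checkpoint values into an already constructed Stable-Baselines model."""
    import tensorflow as tf  # imported lazily; TensorFlow 1.x is assumed

    # Build the Saver inside the model's own graph, as was done when saving,
    # then restore into the session that the model uses for inference.
    with model.graph.as_default():
        saver = tf.train.Saver()
        saver.restore(model.sess, ckpt_prefix)
    return model
```

Usage would be `model = PPO1(MyMlpPolicy, env)` followed by `restore_weights(model)`. If the restore complains about missing or mismatched keys, comparing the output of `list_checkpoint_variables()` with the variables in the rebuilt graph should show where the two differ.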

Miffyli commented 2 years ago

> Yeah that is what I currently use. I just need to run some old code for a comparison. I just have their provided weights.

In that case you should look at the set_parameters function in the SB3 documentation :).
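
As a rough sketch of what that involves: TensorFlow dense kernels are stored as `[in, out]`, while PyTorch `nn.Linear` weights are `[out, in]`, so each kernel has to be transposed before it is handed to `set_parameters`. The helpers below use only NumPy; the function names and the `name_map` argument are hypothetical, and the actual TF variable names and SB3 state-dict keys have to be read off the checkpoint and the SB3 model.

```python
import numpy as np


def tf_dense_to_torch_linear(kernel, bias):
    """Convert a TF dense layer (kernel [in, out]) to the nn.Linear layout (weight [out, in])."""
    kernel = np.asarray(kernel, dtype=np.float32)
    bias = np.asarray(bias, dtype=np.float32)
    if kernel.ndim != 2 or bias.shape != (kernel.shape[1],):
        raise ValueError("expected kernel of shape [in, out] and bias of shape [out]")
    return np.ascontiguousarray(kernel.T), bias


def build_state_dict(tf_vars, name_map):
    """Map TF variables onto PyTorch-style keys.

    tf_vars:  {tf_variable_name: array}, e.g. read from the checkpoint
    name_map: {torch_prefix: (tf_kernel_name, tf_bias_name)}, supplied by the user
              (hypothetical; the real names depend on both models)
    """
    state = {}
    for torch_prefix, (kernel_name, bias_name) in name_map.items():
        weight, bias = tf_dense_to_torch_linear(tf_vars[kernel_name], tf_vars[bias_name])
        state[torch_prefix + ".weight"] = weight
        state[torch_prefix + ".bias"] = bias
    return state
```

The arrays would still need to be converted with `torch.as_tensor` and passed as something like `model.set_parameters({"policy": state}, exact_match=False)`. Checking the keys against `model.policy.state_dict().keys()` first helps catch naming mismatches, and the network architecture and activation function must match the old policy for the comparison to be meaningful.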

You can close this issue if your question has been answered.