DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

How to use normalizations in inference? #946

Closed Trolldemorted closed 2 years ago

Trolldemorted commented 2 years ago

📚 Documentation
We have used your VecNormalize wrapper and like it a lot, but we are wondering how to use the final normalization in setups that only do inference.

We are exporting them as noted in the documentation:

stats_path = os.path.join("archive/" + log_name_, "vec_normalize.pkl")
env.save(stats_path)  

And load them accordingly:

env = DummyVecEnv([OurVeryFancyEnv])
env = VecNormalize.load(stats_path, env)
env.training = False
env.norm_reward = False
model_test = PPO.load("archive/" + log_name_ + "/best_model.zip", env)

Is it truly necessary to wrap our env in a vectorized environment to load/apply the normalizations for inference? Is the format documented somewhere? Can we apply the normalization "manually" to the observations?

araffin commented 2 years ago

Hello,

Is it truly necessary to wrap our env in a vectorized environment to load/apply the normalizations for inference?

If you want to get things working quickly, yes: VecNormalize is a VecEnvWrapper and therefore requires a VecEnv.

But in practice you don't need it if you do the normalization yourself; for that, take a look at the code ;)

https://github.com/DLR-RM/stable-baselines3/blob/7a0163712805756d57f103844acbafb1829cda99/stable_baselines3/common/vec_env/vec_normalize.py#L160

https://github.com/DLR-RM/stable-baselines3/blob/7a0163712805756d57f103844acbafb1829cda99/stable_baselines3/common/vec_env/vec_normalize.py#L181-L188
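
For instance, a minimal sketch of doing it by hand, assuming `my_env` is a plain gym.Env, `model` is a trained agent, and `vec_normalize.pkl` was saved with `env.save()` as above:

import pickle

# load the pickled VecNormalize wrapper and reuse its public methods
with open("vec_normalize.pkl", "rb") as file_handler:
    vec_normalize = pickle.load(file_handler)

obs = my_env.reset()
obs = vec_normalize.normalize_obs(obs)  # apply the saved running statistics
action, _ = model.predict(obs, deterministic=True)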

jank324 commented 2 years ago

I've had the same issue as @Trolldemorted and implemented a somewhat hacky utility wrapper to solve it.

import pickle

import gym


class NotVecNormalize(gym.Wrapper):
    """
    Normal Gym wrapper that replicates the functionality of Stable Baselines3's VecNormalize wrapper
    for non-VecEnvs (i.e. `gym.Env`) in production.
    """

    def __init__(self, env, path):
        super().__init__(env)

        # load the VecNormalize object saved during training via `env.save(path)`
        with open(path, "rb") as file_handler:
            self.vec_normalize = pickle.load(file_handler)

    def reset(self):
        observation = self.env.reset()
        return self.vec_normalize.normalize_obs(observation)

    def step(self, action):
        observation, reward, done, info = self.env.step(action)
        observation = self.vec_normalize.normalize_obs(observation)
        reward = self.vec_normalize.normalize_reward(reward)
        return observation, reward, done, info
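
A minimal usage sketch, assuming the `OurVeryFancyEnv` and the paths from the snippets above:

from stable_baselines3 import PPO

env = NotVecNormalize(OurVeryFancyEnv(), "archive/" + log_name_ + "/vec_normalize.pkl")
model = PPO.load("archive/" + log_name_ + "/best_model.zip")

obs = env.reset()
done = False
while not done:
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)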

I implemented this wrapper because VecEnv automatically resets environments at the end of an episode. While this behaviour is very nice in training, in production it became a bit of a nuisance. (I am using RL for optimisation and want to retain the state achieved by the end of an episode.) Applying the normalisation to a normal gym.Env solves this problem nicely.

I'm not sure if my situation is niche or if there is a better way to disable the automatic reset in VecEnvs, but maybe it would make sense to include a clean version of this utility wrapper in Stable Baselines.

Note: I'm not sure how all of this will be affected by the changes to Gym beyond v0.21 (especially VectorEnv).

Trolldemorted commented 2 years ago

Thanks @araffin and @jank324 !

We had exactly the same issue as @jank324 - we trained on a VecEnv, but the implicit resets were really annoying in our inference setups. Thanks for sharing :)

araffin commented 2 years ago

somewhat hacky utility wrapper to solve it.

@jank324

Thanks for sharing =) Not really hacky in fact: you are using public methods of the saved VecNormalize wrapper. Btw, you don't need to normalize the reward at inference time.

I am using RL for optimisation and want to retain the state achieved by the end of an episode.

we trained on a VecEnv, but the implicit resets were really annoying in our inference setups.

Is the auto-reset annoying for both of you because you cannot get the terminal observation, or is the problem just that you cannot control the reset?

The state achieved at the end of an episode is actually returned by the env in the info dict with the terminal_observation key:

https://github.com/DLR-RM/stable-baselines3/blob/ed308a71be24036744b5ad4af61b083e4fbdf83c/stable_baselines3/common/vec_env/dummy_vec_env.py#L47-L48
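
For example, reusing `env` and `model_test` from the snippet above, a minimal sketch:

obs = env.reset()
done = [False]
while not done[0]:
    action, _ = model_test.predict(obs, deterministic=True)
    obs, reward, done, infos = env.step(action)

# the VecEnv has already auto-reset at this point, but the last real
# observation of the episode is still available in the info dict:
terminal_obs = infos[0]["terminal_observation"]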

jank324 commented 2 years ago

The terminal observation isn't the problem in my case.

My environment is actually interfacing with a physical machine in the real world, specifically a particle accelerator. The agent's task is to change actuator settings until the electron beam has certain properties or a maximum number of steps has been reached. I'm using done to communicate that either of these conditions has been met. In reset, at the start of an episode, all actuators are cleanly brought to settings known to be near good settings. Doing this during training (in simulation) massively reduces training times, but results in agents that expect all actuators to start on these good settings, so it's necessary for this logic to be present in reset in production as well.

Now, if I use a VecEnv, the automatic call to reset resets all the actuators, but what I want is for them to remain on the setting found by the agent. On top of that, I only have one accelerator, so the vectorisation is unnecessary at that point.

araffin commented 2 years ago

Now, if I use a VecEnv, the automatic call to reset resets all the actuators, but what I want is for them to remain on the setting found by the agent.

Sounds like you need to deactivate termination completely at test time (except the first one), or at least deactivate the reset of the actuators after the first episode (should be easy to do), or deactivate the timeout if that is the termination condition you don't need at test time.
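
For the second option, a rough illustrative sketch (the wrapper name is made up, old Gym API as in the rest of the thread):

import gym


class SkipResetAfterFirstEpisode(gym.Wrapper):
    """Performs the real reset only on the first call; afterwards the actuators
    keep whatever settings the agent found."""

    def __init__(self, env):
        super().__init__(env)
        self.has_reset = False
        self.last_obs = None

    def reset(self, **kwargs):
        if not self.has_reset:
            self.has_reset = True
            self.last_obs = self.env.reset(**kwargs)
        return self.last_obs

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.last_obs = obs
        return obs, reward, done, info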