DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Question] How can I wrap a non-image observation trained model via an image observation wrapper? #1871

Closed. zichunxx closed this issue 6 months ago.

zichunxx commented 6 months ago

❓ Question

Hi!

I have a model trained on a custom environment with non-image observations, and I want to collect some image observations while running the trained model.

Below is my brief implementation:


import gymnasium as gym
from stable_baselines3 import SAC
from gymnasium.experimental.wrappers import PixelObservationV0

env = gym.make("CustomEnv", render_mode="rgb_array")
model = SAC.load("trained_model", env=env)

# wrap the vectorized env to get pixel observations
vec_env = model.get_env()
vec_env = PixelObservationV0(vec_env)

obs = vec_env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    # print(obs.shape)
    if done:
        obs = vec_env.reset()

vec_env.close()

But I got a TypeError like this:

Traceback (most recent call last):
  File "/home/xzc/mambaforge/envs/sb-zoo/lib/python3.9/site-packages/gymnasium/core.py", line 515, in reset
    obs, info = self.env.reset(seed=seed, options=options)
TypeError: reset() got an unexpected keyword argument 'seed'

I've checked the documentation and the issue list, but haven't found any example of applying a pixel wrapper to a trained model's environment. Or maybe I missed something.

Many thanks for considering my request.


zichunxx commented 6 months ago

Update:

I found that vec_env.render() can return the image observation after step(), and it doesn't seem to need any wrapper. Is that right?
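
For reference, here is a minimal sketch of what I mean (assuming the env was created with render_mode="rgb_array"; the model still predicts from the non-image observation):

obs = vec_env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    frame = vec_env.render()  # RGB array of the current frame
    if done:
        obs = vec_env.reset()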

araffin commented 6 months ago

Hello, you are mixing gym.Env (the wrapper you used) and VecEnv (please have a look at our docs for the main differences). I'm also not sure how you plan to predict on images with a model that was trained on non-image input.
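
If you do want to go through the wrapper, it has to be applied to the plain gym.Env before it is vectorized. A minimal sketch, assuming PixelObservationV0 accepts pixels_only=False and then returns a dict observation with "state" and "pixels" keys (the defaults of the gymnasium experimental API), so the state stays available for predict():

import gymnasium as gym
from gymnasium.experimental.wrappers import PixelObservationV0
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import DummyVecEnv

def make_env():
    # wrap the plain gym.Env first, then vectorize
    env = gym.make("CustomEnv", render_mode="rgb_array")
    return PixelObservationV0(env, pixels_only=False)

vec_env = DummyVecEnv([make_env])
model = SAC.load("trained_model")

obs = vec_env.reset()
# predict on the state only, keep the pixels for your dataset
action, _states = model.predict(obs["state"], deterministic=True)
image = obs["pixels"][0]  # frame of the first (and only) sub-env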

zichunxx commented 6 months ago

Hi! @araffin Thanks for your reply.

The CNN policy seems hard to get to converge. So I just want to collect images while rolling out the trained model (which predicts from non-image input), to carry out some imitation learning tasks.

vec_env.render() seems to solve my problem without any wrapper, right?

And thanks for your kind reminder. I will check the docs for the differences between these two kinds of wrappers.

araffin commented 6 months ago

> vec_env.render() seems to solve my problem without any wrapper, right?

It should, as long as you use only one env. You might need a bit more work if you use multiple envs (you need to check whether the images are concatenated or not).
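
For instance, a sketch of the difference (get_images() is the VecEnv method that returns one frame per sub-env, while render() combines them into a single image):

# single env: render() directly returns the frame
frame = vec_env.render()

# multiple envs: render() tiles all frames into one image,
# get_images() returns one RGB array per sub-env instead
frames = vec_env.get_images()
for env_idx, frame in enumerate(frames):
    pass  # e.g. save each frame separately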

zichunxx commented 6 months ago

> You might need a bit more work if you use multiple envs (you need to check whether the images are concatenated or not).

I'll take note of what you said, thanks a lot.