google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

New io.image rendering does not work for GymWrapper and inverted_pendulum #375

Closed: JoeMWatson closed this issue 10 months ago

JoeMWatson commented 1 year ago

Hi,

I was excited to try out the new io.image rendering that was added recently, but it doesn't seem to work in a few cases.

I'm using macOS 13.4, Python 3.9.6, Brax 0.9.1 and Jax 0.4.10.

Running

```python
import brax
from brax import envs
from brax.envs.wrappers import gym as gym_wrapper
from brax.io import image
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import traceback

print(f"Using Brax {brax.__version__}, Jax {jax.__version__}")
print("From GymWrapper, env.reset()")
try:
    env = envs.create("inverted_pendulum",
                      batch_size=1,
                      episode_length=150,
                      backend='generalized')
    env = gym_wrapper.GymWrapper(env)
    env.reset()
    img = env.render(mode='rgb_array')
    plt.imshow(img)
except Exception:
    traceback.print_exc()

print("From GymWrapper, env.reset() and action")
try:
    env = envs.create("inverted_pendulum",
                      batch_size=1,
                      episode_length=150,
                      backend='generalized')
    env = gym_wrapper.GymWrapper(env)
    env.reset()
    action = jnp.zeros(env.action_space.shape)
    env.step(action)
    img = env.render(mode='rgb_array')
    plt.imshow(img)
except Exception:
    traceback.print_exc()

print("From brax env")
try:
    env = envs.create("inverted_pendulum",
                      batch_size=1,
                      episode_length=150,
                      backend='generalized')
    key = jax.random.PRNGKey(0)
    initial_env_state = env.reset(key)
    base_state = initial_env_state.pipeline_state
    pipeline_state = env.pipeline_init(base_state.q.ravel(), base_state.qd.ravel())
    img = image.render_array(sys=env.sys, state=pipeline_state, width=256, height=256)
    print(f"pixel values: [{img.min()}, {img.max()}]")
    plt.imshow(img)
    plt.show()
except Exception:
    traceback.print_exc()
```

The output is:

```
Using Brax 0.9.1, Jax 0.4.10
From GymWrapper, env.reset()
Traceback (most recent call last):
  File ".../brax_render_issue.py", line 19, in <module>
    image = env.render(mode='rgb_array')
  File ".../brax/envs/wrappers/gym.py", line 85, in render
    return image.render_array(sys, state.state, 256, 256)
AttributeError: 'State' object has no attribute 'state'
From GymWrapper, env.reset() and action
Traceback (most recent call last):
  File ".../brax_render_issue.py", line 34, in <module>
    image = env.render(mode='rgb_array')
  File ".../brax/envs/wrappers/gym.py", line 85, in render
    return image.render_array(sys, state.state, 256, 256)
AttributeError: 'State' object has no attribute 'state'
From brax env
pixel values: [255, 255]
```
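The traceback points at `state.state` inside `GymWrapper.render`, while the script itself reads the physics state from `initial_env_state.pipeline_state` without any error. That suggests the wrapper is using a stale attribute name. The sketch below reproduces this failure mode with a hypothetical, simplified `State` dataclass (an assumption, not Brax's actual class); the helpers `read_stale_field` and `read_pipeline_state` are likewise illustrative and do not describe the fix that was eventually pushed.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class State:
    """Hypothetical stand-in for the env State in the report (subset of fields only)."""
    pipeline_state: Any  # physics state, the field the reporter's script reads successfully
    obs: Any
    reward: float
    done: float


def read_stale_field(state: State) -> Any:
    # Mirrors the failing line in gym.py: there is no attribute named `state`,
    # so this raises AttributeError.
    return state.state


def read_pipeline_state(state: State) -> Any:
    # Reads the field that actually exists on the stand-in State.
    return state.pipeline_state


if __name__ == "__main__":
    s = State(pipeline_state="physics", obs=None, reward=0.0, done=0.0)
    try:
        read_stale_field(s)
    except AttributeError as err:
        print(err)  # 'State' object has no attribute 'state'
    print(read_pipeline_state(s))  # physics
```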

So rendering through GymWrapper fails both with and without a simulation step, and calling the underlying image.render_array function directly returns an all-white image (minimum and maximum pixel values are both 255).

I ran this across all the environments: the GymWrapper error occurred for every one of them, but the white-image issue is specific to inverted_pendulum, and the other environments rendered correctly. I did not check the dm_env wrapper.
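The script's `pixel values: [255, 255]` line is a blank-frame check: when the minimum and maximum are both 255, every pixel is white. When sweeping across environments as described above, a small numpy-only helper can flag such frames automatically. This is a minimal sketch; `is_uniform_frame` and `uniform_frames` are hypothetical helpers, not part of Brax, and the frames in the demo are synthetic. In practice the dictionary would map each environment name to the array returned by `image.render_array`.

```python
import numpy as np


def is_uniform_frame(img) -> bool:
    """Return True if every value in the frame is identical, e.g. an all-white image.

    min == max == 255 corresponds to the all-white output seen for inverted_pendulum.
    """
    arr = np.asarray(img)
    return arr.size > 0 and bool(arr.min() == arr.max())


def uniform_frames(frames: dict) -> list:
    """Return the sorted names of entries whose rendered frame is uniform."""
    return sorted(name for name, img in frames.items() if is_uniform_frame(img))


if __name__ == "__main__":
    # Synthetic frames: one all-white, one horizontal gradient.
    white = np.full((256, 256, 3), 255, dtype=np.uint8)
    gradient = np.tile(np.arange(256, dtype=np.uint8), (256, 1))[..., None].repeat(3, axis=-1)
    print(uniform_frames({"white_frame": white, "gradient_frame": gradient}))  # ['white_frame']
```

Checking min against max rather than comparing to 255 alone also catches frames that come back uniformly black or any other single color.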

btaba commented 10 months ago

Hi @JoeMWatson, thanks for the bug report! Just pushed a fix for this.