google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Rendering an image with brax.io.image #460

Closed: eleninisioti closed this issue 4 months ago.

eleninisioti commented 4 months ago

I have seen this functionality mentioned in other issues, but I cannot find code that works. Here is my code that attempts to get an image of the environment at every step:

import jax.numpy as jnp
from brax import envs
from brax.io import html
from brax.io import image
from jax import random
import jax

if __name__ == "__main__":

    env_name = "ant"
    backend = "generalized"
    episode_length = 5

    env = envs.get_environment(backend=backend, env_name=env_name)

    state = jax.jit(env.reset)(random.PRNGKey(0))
    step_fn = jax.jit(env.step)  # jit once, outside the loop

    cum_reward = 0
    states = []
    for step in range(1, episode_length + 1):

        action = jnp.zeros(env.action_size)
        state = step_fn(state, action)
        states.append(state.pipeline_state)
        cum_reward += state.reward

        # This line raises the error below; comment it out to get the HTML visualization.
        # image.render expects a list of pipeline states, not the env State.
        step_image = image.render(env.sys, [state.pipeline_state])

        if state.done:
            break

    render = html.render(env.sys, states)

    with open("traj.html", "w") as f:
        f.write(render)

    print("Episode ended. Total reward: " + str(cum_reward))

This throws the following error:

  File "/home/eleni/anaconda3/envs/brax_env/lib/python3.11/site-packages/brax/io/image.py", line 34, in render_array
    renderer = mujoco.Renderer(sys.mj_model, height=height, width=width)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eleni/anaconda3/envs/brax_env/lib/python3.11/site-packages/mujoco/renderer.py", line 83, in __init__
    self._mjr_context = _render.MjrContext(
                        ^^^^^^^^^^^^^^^^^^^
mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called
Exception ignored in: <function Renderer.__del__ at 0x7fa3777ec860>
Traceback (most recent call last):
  File "/home/eleni/anaconda3/envs/brax_env/lib/python3.11/site-packages/mujoco/renderer.py", line 330, in __del__
    self.close()
  File "/home/eleni/anaconda3/envs/brax_env/lib/python3.11/site-packages/mujoco/renderer.py", line 318, in close
    if self._mjr_context:
       ^^^^^^^^^^^^^^^^^
AttributeError: 'Renderer' object has no attribute '_mjr_context'

If you comment out the line calling image.render, you will see that html.render works, so I am not sure why one works and the other does not.

I am using brax 0.10.0 and mujoco 3.1.2.

jamesheald commented 4 months ago

You need to install the system dependencies for OpenGL rendering. See https://pytorch.org/rl/reference/generated/knowledge_base/MUJOCO_INSTALLATION.html ('Prerequisite for rendering').
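As far as can be told from brax.io.html, html.render only produces an HTML page that the browser draws with JavaScript, while image.render builds a mujoco.Renderer, which needs an OpenGL context inside the Python process (that is exactly what the traceback above complains about). MuJoCo's Python bindings choose the OpenGL backend from the MUJOCO_GL environment variable (glfw, egl, or osmesa on Linux), and it has to be set before mujoco, or anything that imports it such as brax.io.image, is imported. A minimal sketch, assuming a machine without a display should fall back to EGL; choose_mujoco_gl is a hypothetical helper, not part of brax or mujoco:

```python
import os
from typing import MutableMapping, Optional

# Offscreen backends MuJoCo accepts on Linux (no window system needed).
HEADLESS_BACKENDS = ("egl", "osmesa")


def choose_mujoco_gl(
    environ: Optional[MutableMapping[str, str]] = None,
    headless_backend: str = "egl",
) -> str:
    """Pick an OpenGL backend for MuJoCo and record it in MUJOCO_GL.

    Hypothetical helper: call it before importing mujoco or brax.io.image,
    since MuJoCo reads MUJOCO_GL at import time.
    """
    if headless_backend not in HEADLESS_BACKENDS:
        raise ValueError(f"headless_backend must be one of {HEADLESS_BACKENDS}")
    if environ is None:
        environ = os.environ
    # Respect a backend the user has already chosen.
    if environ.get("MUJOCO_GL"):
        return environ["MUJOCO_GL"]
    # GLFW creates a (hidden) window, so it needs a display; otherwise render offscreen.
    has_display = bool(environ.get("DISPLAY") or environ.get("WAYLAND_DISPLAY"))
    backend = "glfw" if has_display else headless_backend
    environ["MUJOCO_GL"] = backend
    return backend
```

In a script this would be the very first call, followed by `from brax.io import image`. EGL needs a GPU driver that ships libEGL; osmesa renders in software through libOSMesa, which is slower but needs no GPU.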

eleninisioti commented 4 months ago

Thank you! I installed the dependencies with sudo apt-get install libglfw3 libgl1-mesa-glx libosmesa6 and conda install -c conda-forge glew (because apt-get could not find the package libglew2.0).
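To check that the OpenGL libraries installed this way are visible to the Python process, one option is ctypes.util.find_library from the standard library. This is a diagnostic sketch: the short names ("GL", "glfw", "EGL", "OSMesa") are an assumption about how these packages name their .so files on Debian/Ubuntu, and find_library may not see libraries that live only inside a conda environment's lib directory. The helper names are hypothetical:

```python
import ctypes.util
from typing import Dict, Iterable, List, Optional

# Short names as passed to ctypes.util.find_library ("glfw" -> libglfw.so.*).
DEFAULT_GL_LIBS = ("GL", "glfw", "EGL", "OSMesa")


def find_gl_libraries(names: Iterable[str] = DEFAULT_GL_LIBS) -> Dict[str, Optional[str]]:
    """Map each library short name to the file the loader finds, or None."""
    return {name: ctypes.util.find_library(name) for name in names}


def missing_libraries(found: Dict[str, Optional[str]]) -> List[str]:
    """Return the short names that could not be located, sorted."""
    return sorted(name for name, path in found.items() if path is None)
```

For example, `missing_libraries(find_gl_libraries())` returning `["OSMesa"]` would suggest that MUJOCO_GL=osmesa cannot work on that machine yet.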

I no longer get that error, but now I get this one:

    step_image = image.render(env.sys, states)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eleni/anaconda3/envs/brax_env/lib/python3.11/site-packages/brax/io/image.py", line 66, in render
    frames[0].save(f, format=fmt)
    ^^^^^^^^^^^^^^
AttributeError: 'numpy.ndarray' object has no attribute 'save'. Did you mean: 'ravel'?

My numpy version is 1.24.3.

jamesheald commented 4 months ago

I also encountered the same save error (using brax 0.10.0 and mujoco 3.1.2).

Judging from brax/io/image.py in the previous brax release (0.9.4), each element of frames needs to be wrapped in Image.fromarray(...) before save is called. For example:

from PIL import Image

....

frames = [Image.fromarray(frame) for frame in frames]

This looks like a bug to me.
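The gist of this fix can be reproduced outside brax: the renderer returns NumPy arrays, and only PIL images have a save method. The sketch below mirrors what image.render appears to do once each frame is wrapped with Image.fromarray (a single frame is saved directly, several frames are saved as an animation). encode_frames is a hypothetical helper, not brax's API, and the default frame duration is an arbitrary choice:

```python
import io
from typing import Sequence

import numpy as np
from PIL import Image


def encode_frames(frames: Sequence[np.ndarray], fmt: str = "png", duration_ms: int = 50) -> bytes:
    """Encode HxWx3 uint8 arrays as one image (one frame) or an animation (several)."""
    if len(frames) == 0:
        raise ValueError("need at least one frame")
    # np.ndarray has no .save(); wrapping each frame in a PIL image fixes that.
    images = [Image.fromarray(np.asarray(frame)) for frame in frames]
    buf = io.BytesIO()
    if len(images) == 1:
        images[0].save(buf, format=fmt)
    else:
        images[0].save(
            buf,
            format=fmt,
            save_all=True,             # write every frame, not just the first
            append_images=images[1:],
            duration=duration_ms,      # per-frame display time in milliseconds
            loop=0,                    # loop forever
        )
    return buf.getvalue()
```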

btaba commented 4 months ago

Thanks for the bug report and for finding the issue; we'll push out a fix.

btaba commented 4 months ago

This should be fixed in 0825bcb74b53e36a62d50405a85e540fe2c25a95; please let us know if that doesn't work! Closing for now.

jamesheald commented 4 months ago

There's another issue: the renderings are not in color. Here are some example GIFs and the code used to produce them.

[Example GIFs: reacher, ant]

import jax.numpy as jnp
from brax import envs
from brax.io import image
from jax import random
import jax
from IPython.display import Image

env_name = "reacher"
backend = "generalized"
episode_length = 5

env = envs.get_environment(backend=backend, env_name=env_name)

state = jax.jit(env.reset)(random.PRNGKey(0))
step_fn = jax.jit(env.step)  # jit once, outside the loop

cum_reward = 0
states = []
rollout = []
for step in range(1, episode_length + 1):

    rollout.append(state.pipeline_state)

    action = jnp.zeros(env.action_size)
    state = step_fn(state, action)
    states.append(state.pipeline_state)
    cum_reward += state.reward

    if state.done:
        break

# Render the rollout as an animated GIF (displayable in a notebook) and save it.
gif = Image(image.render(env.sys, rollout, fmt='gif'))
with open('reacher.gif', 'wb') as f:
    f.write(gif.data)
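It is not clear from the thread where the color gets lost. One way to narrow it down is to check whether the raw frames are already gray (all three channels equal) before any GIF encoding, and, if they are colored, to quantize each frame to its own adaptive palette explicitly before saving. A sketch with hypothetical helper names; the frames would come from something like image.render_array(env.sys, rollout), whose return value is assumed here to be a list of HxWx3 uint8 arrays:

```python
import io
from typing import Sequence

import numpy as np
from PIL import Image


def is_grayscale(frame: np.ndarray) -> bool:
    """True if every pixel has identical R, G and B values."""
    frame = np.asarray(frame)
    if frame.ndim == 2:
        return True
    rgb = frame[..., :3]
    return bool(np.all(rgb == rgb[..., :1]))


def frames_to_color_gif(frames: Sequence[np.ndarray], duration_ms: int = 50) -> bytes:
    """Save RGB frames as a looping GIF, quantizing each frame explicitly."""
    if len(frames) == 0:
        raise ValueError("need at least one frame")
    images = []
    for f in frames:
        rgb = np.ascontiguousarray(np.asarray(f)[..., :3])  # drop alpha if present
        # quantize() builds a palette of up to 256 colors from this frame's RGB data.
        images.append(Image.fromarray(rgb).quantize(colors=256))
    buf = io.BytesIO()
    images[0].save(
        buf,
        format="GIF",
        save_all=True,
        append_images=images[1:],
        duration=duration_ms,
        loop=0,
    )
    return buf.getvalue()
```

If is_grayscale returns True for the raw frames, the color is already missing before encoding and the GIF step is not the culprit; if it returns False, comparing image.render's GIF with frames_to_color_gif's output shows whether the encoding step drops the color.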