Closed markusheimerl closed 1 month ago
You are 100% right. We are working on option 2.
@yuvaltassa Would you mind making the corresponding feature branch public in this repo? I'd love to contribute.
The moment we have something that works it will be OSS and we would love you to contribute!
@erikfrey is leading this effort, perhaps there is something he'd like to add.
Hello! Please see #1604 and #1485 for related discussions. You can do visual observations today using `mjx.ray`, although this only works for toy environments.
We are working on integrating Madrona as a means for high throughput tiled rendering on GPU, but this is still very much a work in progress. We'll share more once we have a good proof of concept - no ETA but this is actively under development.
@yuvaltassa @erikfrey thanks for your input! Have you considered https://github.com/JoeyTeng/jaxrenderer ?
```python
import re

import jax
import numpy as onp
from jax import numpy as jp
from numpngw import write_apng
from PIL import Image
from renderer import (CameraParameters, LightParameters, Model, ModelObject,
                      Renderer, ShadowParameters, transpose_for_display)
from renderer.geometry import rotation_matrix

# Load model and textures
obj_path, texture_path, spec_path = "african_head.obj", "african_head_diffuse.tga", "african_head_spec.tga"

image = Image.open(texture_path)
width, height = image.size
texture = onp.zeros((width, height, 3))
for y in range(height):
    for x in range(width):
        texture[y, x] = onp.array(image.getpixel((x, y)))
texture = jp.array(texture, dtype=jp.single) / 255

image = Image.open(spec_path)
specular_map = onp.zeros((width, height, 3))
for y in range(height):
    for x in range(width):
        specular_map[y, x] = onp.array(image.getpixel((x, y)))
specular_map = jp.array(specular_map, dtype=jp.single)[..., 0]

# Parse the Wavefront .obj file
verts, norms, uv, faces, faces_norm, faces_uv = [], [], [], [], [], []
_float = re.compile(r"(-?\d+\.?\d*(?:e[+-]\d+)?)")
_integer = re.compile(r"\d+")
_one_vertex = re.compile(r"\d+/\d*/\d*")
with open(obj_path, 'r') as file:
    for line in file:
        if line.startswith("v "):
            verts.append(tuple(map(float, _float.findall(line, 2)[:3])))
        elif line.startswith("vn "):
            norms.append(tuple(map(float, _float.findall(line, 2)[:3])))
        elif line.startswith("vt "):
            uv.append(tuple(map(float, _float.findall(line, 2)[:2])))
        elif line.startswith("f "):
            face, face_norm, face_uv = [], [], []
            vertices = _one_vertex.findall(line)
            assert len(vertices) == 3, f"Expected 3 vertices, got {len(vertices)}"
            for vertex in vertices:
                v, vt, vn = list(map(int, _integer.findall(vertex)))
                face.append(v - 1)
                face_norm.append(vn - 1)
                face_uv.append(vt - 1)
            faces.append(face)
            faces_norm.append(face_norm)
            faces_uv.append(face_uv)

model = Model(
    verts=jp.array(verts),
    norms=jp.array(norms),
    uvs=jp.array(uv),
    faces=jp.array(faces),
    faces_norm=jp.array(faces_norm),
    faces_uv=jp.array(faces_uv),
    diffuse_map=jp.swapaxes(texture, 0, 1)[:, ::-1, :],
    specular_map=jp.swapaxes(specular_map, 0, 1)[:, ::-1],
)

# Camera, light and shadow parameters
canvas_width, canvas_height, frames, rotation_axis = 1920, 1080, 30, "Y"
rotation_axis = dict(X=(1., 0., 0.), Y=(0., 1., 0.), Z=(0., 0., 1.))[rotation_axis]
degrees = jax.lax.iota(float, frames) * 360. / frames
eye, center, up = jp.array((0, 0, 3.)), jp.array((0, 0, 0)), jp.array((0, 1, 0))
camera = CameraParameters(viewWidth=canvas_width, viewHeight=canvas_height,
                          position=eye, target=center, up=up)
light = LightParameters(direction=jp.array([0.57735, -0.57735, 0.57735]),
                        ambient=0.1, diffuse=0.85, specular=0.05)
shadow = ShadowParameters(centre=center)

@jax.default_matmul_precision("float32")
def render_instances(instances, width, height, camera, light, shadow):
    img = Renderer.get_camera_image(
        objects=instances, light=light, camera=camera, width=width,
        height=height, shadow_param=shadow,
        colour_default=jp.zeros(3, dtype=jp.single))
    return jax.lax.clamp(0., img, 1.)

def rotate(model, rotation_axis, degree):
    instance = ModelObject(model=model)
    return instance.replace_with_orientation(
        rotation_matrix=rotation_matrix(rotation_axis, degree))

# One rotated instance per frame, computed in a single vmapped call
batch_rotation = jax.jit(
    jax.vmap(lambda degree: rotate(model, rotation_axis, degree))
).lower(degrees).compile()
instances = [batch_rotation(degrees)]

@jax.jit
def render(batched_instances):
    def _render(instances):
        render_fn = jax.jit(render_instances,
                            static_argnames=("width", "height"), inline=True)
        img = render_fn(instances=instances, width=canvas_width,
                        height=canvas_height, camera=camera, light=light,
                        shadow=shadow)
        return transpose_for_display((img * 255).astype(jp.uint8))
    return jax.jit(jax.vmap(_render))(batched_instances)

render_compiled = jax.jit(render).lower(instances).compile()
images = list(map(onp.asarray, jax.device_get(render_compiled(instances))))
write_apng('animation.png', images, delay=1/30.)
# ffmpeg -i animation.png intermediate.gif
# gifsicle --optimize=3 --delay=5 intermediate.gif > output.gif
```
All these views were rendered in parallel, with JAX as the only dependency:
Thank you for your input. Godspeed on integrating Madrona.
results in:
Description: I'm encountering an issue while using MuJoCo with JAX (mjx) for training a humanoid model in the Brax environment. The problem arises when attempting to render the environment state and retrieve camera images during training. The MuJoCo renderer does not seem to work properly when using mjx.
Problem Details: When calling the `render` method within the custom `Humanoid` class, an error is thrown: `jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[28]`. This error indicates that the conversion method is called on a traced array, whose value depends on the argument `state.pipeline_state.q`. The current implementation of the MuJoCo renderer does not handle traced arrays in this context when used with mjx.
Importance of Camera Input in Reinforcement Learning: Camera input is crucial when training robots with reinforcement learning. In real-world scenarios, robots rely on visual information captured by their cameras to perceive and interact with the environment. By incorporating camera pixels into the observation space during training, the learned policies become more robust and transfer better to real-world conditions.
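This class of error is easy to reproduce outside of MuJoCo: any NumPy conversion of a value that is being traced under `jax.jit` fails, because the tracer has no concrete data. A minimal sketch (the function name is illustrative, not from the Brax code):

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def get_obs(q):
    # Inside jit, q is a tracer with no concrete value, so any NumPy
    # conversion (np.asarray here; a CPU renderer call in practice)
    # triggers TracerArrayConversionError.
    return np.asarray(q)

try:
    get_obs(jnp.zeros(28))
except jax.errors.TracerArrayConversionError as e:
    print(type(e).__name__)  # TracerArrayConversionError
```

This is exactly what happens when the classic MuJoCo renderer, which expects concrete NumPy arrays, is handed `state.pipeline_state.q` from inside a jitted training step.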
Proposed Solution: To enable effective training with camera input using mjx, it is essential to address the compatibility issue between the MuJoCo renderer and JAX traced arrays. Possible solutions include:
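One workaround worth noting (a sketch, not an official mjx feature): traced values can be shipped to the host with `jax.pure_callback` and rendered there by ordinary Python code, at the cost of a host round-trip per call. The resolution and the `host_render` placeholder below are assumptions; in practice it would wrap a CPU renderer such as `mujoco.Renderer`.

```python
import jax
import jax.numpy as jnp
import numpy as np

H, W = 64, 64  # hypothetical camera resolution

def host_render(q):
    # Placeholder for a host-side, non-traceable renderer; q arrives
    # here as a concrete NumPy array, so NumPy-based code is safe.
    return np.zeros((H, W, 3), dtype=np.uint8)

@jax.jit
def observe(q):
    # pure_callback sends the traced value to the host, runs the
    # Python renderer there, and returns pixels with a declared
    # shape/dtype so tracing can continue.
    return jax.pure_callback(
        host_render,
        jax.ShapeDtypeStruct((H, W, 3), np.uint8),
        q,
    )

img = observe(jnp.zeros(28))
print(img.shape)  # (64, 64, 3)
```

This sidesteps the tracer conversion error but serializes rendering on the host, so it does not deliver the high-throughput GPU rendering that the Madrona integration mentioned above targets.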
Alternatives Considered:
Additional Context: Integrating camera input into reinforcement learning algorithms is crucial for developing intelligent and adaptable robots. JAX and mjx can greatly accelerate training, but the current incompatibility between the MuJoCo renderer and JAX traced arrays prevents camera input from being used effectively in this setup.
Addressing this issue and providing seamless integration between the MuJoCo renderer and mjx would greatly benefit the robotics and reinforcement learning community, enabling researchers to train policies that process visual information directly.
Thank you for considering this issue.