Visual Observations and video recordings? #331

StoneT2000 commented 1 year ago

Just wondering whether this is possible at the moment; it seems the old module is gone in Brax v2.

Moreover, if rendering is supported, is it GPU-parallelizable, or is only state available now (with single-threaded rendering for an eval env)?


btaba commented 1 year ago

Hi @StoneT2000, we still don't have a JAX renderer. We may add back a (single-threaded) renderer based on pytinyrenderer, but here's an implementation if you're looking to use it ASAP:

"""Exports a system config and state as an image."""

import io
from typing import List, Optional, Tuple

import brax
from brax import base
from brax import math
import jax
from jax import numpy as jp
import numpy as onp
from PIL import Image
from pytinyrenderer import TinyRenderCamera as Camera
from pytinyrenderer import TinyRenderLight as Light
from pytinyrenderer import TinySceneRenderer as Renderer

class TextureRGB888:
  """A square texture given as a flat [r, g, b, r, g, b, ...] value sequence.

  Attributes:
    pixels: the flat channel sequence, stored as passed in.
    width: side length in texels, assuming the texture is square.
    height: identical to ``width``.
  """

  def __init__(self, pixels):
    self.pixels = pixels
    # Three channel values per texel; a square texture's side is the square
    # root of the texel count, truncated toward zero by int().
    side = int(onp.sqrt(len(pixels) / 3))
    self.width = side
    self.height = side

class Grid(TextureRGB888):
  """A square solid-color texture whose first row and column are black.

  Presumably tiled across a surface (it is used for the ground plane) so the
  black row/column repeat as grid lines.
  """

  def __init__(self, grid_size, color):
    grid = onp.zeros((grid_size, grid_size, 3), dtype=onp.int32)
    grid[:, :] = onp.array(color)
    # Black out the first row and the first column of texels.
    grid[0] = 0
    grid[:, 0] = 0
    # Hand the flattened texels to the base class so that `pixels`, `width`
    # and `height` exist; without this call the texture has no pixel data.
    super().__init__(grid.ravel().tolist())

# Single-texel (1x1) solid-color textures; channel values are in 0-255.
_BASIC = TextureRGB888(pixels=[133, 118, 102])
_TARGET = TextureRGB888(pixels=[255, 34, 34])
# Light-gray 100x100 grid texture applied to ground planes.
_GROUND = Grid(grid_size=100, color=[200, 200, 200])

def _flatten_vectors(vectors):
  """Returns the flattened array of the vectors."""
  return sum(map(lambda v: [v.x, v.y, v.z], vectors), [])

def _scene(sys: brax.System, state: brax.State) -> Tuple[Renderer, List[int]]:
  """Converts a brax System and state to a pytinyrenderer scene and instances.

  Args:
    sys: the brax system whose ``geoms`` are turned into renderer models.
    state: the state whose transforms ``state.x`` place each link.

  Returns:
    A tuple ``(scene, instances)``: the populated renderer and the object
    instances created in it, one per geometry.

  Raises:
    RuntimeError: if a geometry type has no renderer counterpart.
  """
  scene = Renderer()
  instances = []

  # TODO: render base.Mesh geoms (e.g. scene.create_mesh fed with vertices and
  # normals flattened by _flatten_vectors).

  def take_i(obj, i):
    # Slice the i-th geometry out of a batched geometry pytree.
    return jax.tree_util.tree_map(lambda x: jp.take(x, i, axis=0), obj)

  link_names = [n or f'link {i}' for i, n in enumerate(sys.link_names)]
  link_names += ['world']
  link_geoms = {}
  for batch in sys.geoms:
    num_geoms = len(batch.friction)
    for i in range(num_geoms):
      link_idx = -1 if batch.link_idx is None else batch.link_idx[i]
      link_geoms.setdefault(link_names[link_idx], []).append(take_i(batch, i))

  # Append one identity transform so that index -1 addresses a "world" frame
  # for geometries not attached to any link. Loop-invariant, so built once.
  x = state.x.concatenate(base.Transform.zero((1,)))

  for geoms in link_geoms.values():
    for col in geoms:
      tex = TextureRGB888((col.rgba[:3] * 255).astype('uint32'))
      if isinstance(col, base.Capsule):
        half_height = col.length / 2
        model = scene.create_capsule(col.radius, half_height, 2,
                                     tex.pixels, tex.width, tex.height)
      elif isinstance(col, base.Box):
        model = scene.create_cube(col.halfsize, tex.pixels, tex.width,
                                  tex.height, 16.)
      elif isinstance(col, base.Sphere):
        # A sphere is drawn as a capsule with zero half-height.
        model = scene.create_capsule(col.radius, 0, 2, tex.pixels,
                                     tex.width, tex.height)
      elif isinstance(col, base.Plane):
        # Planes are drawn as a very large, very thin box with a grid texture.
        tex = _GROUND
        model = scene.create_cube([1000.0, 1000.0, 0.0001], tex.pixels,
                                  tex.width, tex.height, 8192)
      else:
        raise RuntimeError(f'unrecognized collider: {type(col)}')

      i = col.link_idx if col.link_idx is not None else -1
      instance = scene.create_object_instance(model)
      # Compose the geometry's local offset/rotation with its link transform.
      off = col.transform.pos
      pos = onp.array(x.pos[i]) + math.rotate(off, x.rot[i])
      rot = math.quat_mul(x.rot[i], col.transform.rot)
      scene.set_object_position(instance, list(pos))
      # Reorder the quaternion from (w, x, y, z) to the (x, y, z, w) order
      # passed to the renderer.
      scene.set_object_orientation(instance, [rot[1], rot[2], rot[3], rot[0]])
      # Without this the returned list is empty and nothing gets drawn.
      instances.append(instance)

  return scene, instances

def _eye(sys: brax.System, state: brax.State) -> List[float]:
  """Determines the camera location for a Brax system.

  The camera is offset from the first link's position by a distance scaled to
  the longest joint-to-parent-joint segment, so larger systems are viewed
  from further away.

  Args:
    sys: the brax system; ``link_parents`` and ``link.joint`` are read.
    state: the state whose transforms ``state.x`` place the links.

  Returns:
    The camera position as a list of three Python floats.
  """
  parents = onp.asarray(sys.link_parents)
  has_parent = parents > -1
  # Presumably the joint frames expressed in world coordinates.
  # NOTE(review): the argument of `.do(` was truncated in the pasted snippet;
  # `sys.link.joint` is a reconstruction - confirm against brax.
  xj = state.x.vmap().do(sys.link.joint)
  joint_pos = onp.asarray(xj.pos)
  # Root links (parent -1) are indexed with 0 instead of -1 so they never read
  # another link's (or an out-of-bounds) position; they are masked to 0 below.
  parent_pos = joint_pos[onp.where(has_parent, parents, 0)]
  dist = onp.linalg.norm(joint_pos - parent_pos, axis=1)
  # `initial=0.0` keeps this well-defined for systems without links.
  dist = float(onp.max(onp.where(has_parent, dist, 0.0), initial=0.0))
  off = onp.array([2 * dist, -2 * dist, dist])
  return (onp.asarray(state.x.pos[0, :]) + off).tolist()

def _up(unused_sys: brax.System) -> List[float]:
  """Returns the camera up vector, which is +z regardless of the system."""
  z_up = [0, 0, 1]
  return z_up

def render_array(sys: brax.System,
                 state: brax.State,
                 width: int,
                 height: int,
                 light: Optional[Light] = None,
                 camera: Optional[Camera] = None,
                 ssaa: int = 2) -> onp.ndarray:
  """Renders an RGB array of a brax system and state.

  Args:
    sys: the brax system to render.
    state: a single, unbatched state; ``state.x.pos`` and ``state.x.rot`` must
      both be 2-D.
    width: requested image width in pixels.
    height: requested image height in pixels.
    light: optional light; by default a fixed directional light whose shadow
      map is centred on the first link's position.
    camera: optional camera; by default one placed by ``_eye`` that looks at
      the first link and renders at ``ssaa`` times the requested size.
    ssaa: supersampling factor; when greater than 1 the rendered image is
      resized to ``(width, height)``.

  Returns:
    A uint8 array of shape ``(rows, cols, channels)``. This is
    ``(height, width, ...)`` when ``ssaa > 1``; otherwise it is the camera's
    own view size.

  Raises:
    RuntimeError: if the state's positions or rotations are not 2-D.
  """
  if (len(state.x.pos.shape), len(state.x.rot.shape)) != (2, 2):
    raise RuntimeError('unexpected shape in state')
  scene, instances = _scene(sys, state)
  # Plain Python floats for the renderer bindings.
  target = onp.asarray(state.x.pos[0, :]).tolist()
  if light is None:
    direction = [0.57735, -0.57735, 0.57735]
    light = Light(
        direction=direction,
        ambient=0.8,
        diffuse=0.8,
        specular=0.6,
        shadowmap_center=target)
  if camera is None:
    eye, up = _eye(sys, state), _up(sys)
    hfov = 58.0
    vfov = hfov * height / width
    # Render at ssaa x the requested size; downsampled below.
    camera = Camera(
        viewWidth=width * ssaa,
        viewHeight=height * ssaa,
        position=eye,
        target=target,
        up=up,
        hfov=hfov,
        vfov=vfov)
  img = scene.get_camera_image(instances, light, camera).rgb
  arr = onp.reshape(
      onp.array(img, dtype=onp.uint8),
      (camera.view_height, camera.view_width, -1))
  if ssaa > 1:
    arr = onp.asarray(Image.fromarray(arr).resize((width, height)))
  return arr

def render(sys: brax.System,
           states: List[brax.State],
           width: int,
           height: int,
           light: Optional[Light] = None,
           cameras: Optional[List[Camera]] = None,
           ssaa: int = 2,
           fmt='png') -> bytes:
  """Returns an encoded image of a brax system, one frame per state.

  Args:
    sys: the brax system to render; ``sys.dt`` sets the animation frame time.
    states: the states to render, in order.
    width: image width in pixels.
    height: image height in pixels.
    light: optional light shared by all frames (see ``render_array``).
    cameras: optional per-state cameras; must match ``states`` in length.
    ssaa: supersampling factor forwarded to ``render_array``.
    fmt: Pillow format name used for encoding, e.g. 'png' or 'gif'.

  Returns:
    The encoded bytes. A single state produces a still image; several states
    are saved as an animation that loops forever.

  Raises:
    RuntimeError: if ``states`` is empty or ``cameras`` has a different length.
  """
  if not states:
    raise RuntimeError('must have at least one state')
  if cameras is None:
    cameras = [None] * len(states)
  elif len(cameras) != len(states):
    # zip() below would otherwise silently drop the trailing states.
    raise RuntimeError(
        f'got {len(cameras)} cameras for {len(states)} states')

  frames = [
      Image.fromarray(
          render_array(sys, state, width, height, light, camera, ssaa))
      for state, camera in zip(states, cameras)
  ]
  f = io.BytesIO()
  if len(frames) == 1:
    frames[0].save(f, format=fmt)
  else:
    frames[0].save(
        f,
        format=fmt,
        append_images=frames[1:],
        save_all=True,
        # Pillow expects the per-frame duration in milliseconds.
        duration=float(sys.dt) * 1000,
        loop=0)
  return f.getvalue()