google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0
2.14k stars 234 forks

io/image: Replace PyTinyrenderer with a Pure JAX renderer #367

Open JoeyTeng opened 1 year ago

JoeyTeng commented 1 year ago

This resolves #331, #108, #47 and #67, and is related to #302.

Example Colab (most of the code is adapted from the Brax team's Brax Training)

Previously known issue: the plane might not be displayed in full, because vertices behind or at the camera plane were not clipped. The plan was to fix this once a proper clipping algorithm was implemented in JaxRenderer. The plane rendering issue is now solved by implementing rasterisation based on homogeneous interpolation.

Currently this only changes the v2 API, and the changes are minimal. Feel free to tell me if you think it should be backported to the v1 API as well. I will try to optimise the performance further later.
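
For readers unfamiliar with the term, the interpolation step behind this kind of rasteriser can be sketched in plain Python. This is only an illustrative sketch of perspective-correct (homogeneous) attribute interpolation, not JaxRenderer's actual code: the function name `interpolate_perspective` and its arguments are hypothetical, and the sketch does not cover how vertices at or behind the camera plane are handled.

```python
def interpolate_perspective(bary, attrs, ws):
    """Perspective-correct interpolation of a per-vertex scalar attribute.

    bary:  screen-space barycentric weights (b0, b1, b2), summing to 1
    attrs: attribute value at each of the three vertices
    ws:    clip-space w of each vertex (assumed > 0 in this sketch)
    """
    # The attribute itself is not linear in screen space, but attr/w and 1/w are,
    # so interpolate those two quantities and divide at the end.
    numerator = sum(b * a / w for b, a, w in zip(bary, attrs, ws))
    denominator = sum(b / w for b, w in zip(bary, ws))
    return numerator / denominator


# Screen-space midpoint of an edge whose second vertex has three times the w.
value = interpolate_perspective((0.5, 0.5, 0.0), (0.0, 1.0, 0.0), (1.0, 3.0, 1.0))
print(value)
```

When all three w values are equal, this reduces to ordinary linear interpolation. In the example, the on-screen midpoint maps to roughly 0.25 rather than 0.5, because perspective compresses the far end of the edge.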

Update:

  1. with 8a111d9, I jitted the rendering of each frame when rendering a batch of states.
  2. 828848a refactors batch rendering a bit, but with no significant improvement in performance (if any); see the simple benchmark here
  3. with JoeyTeng/jaxrenderer#2 (release 0.3.0), performance improves by about 10x. Rendering one frame of the Ant environment at 960x540 resolution with 1x SSAA now takes about 500ms on a T4 (2 fps) and ~70-100ms on an A100 (10-14 fps). The previous CPU implementation takes about 200ms per frame (5 fps).
  4. with JoeyTeng/jaxrenderer#3 (release 0.3.1), the minimum Python version is lowered to Python 3.8, the same as brax.

JoeyTeng commented 1 year ago

Friendly ping @erikfrey. With the release of jaxrenderer 0.3.0, I believe this is a suitable replacement for the current CPU renderer, pytinyrenderer, because 1) it is a pure JAX implementation; and 2) it performs better on high-end GPUs like the A100 (>2x speedup). Let me know what you think :)
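
As a quick check of the speedup figure, the per-frame timings reported earlier in this thread can be converted to frame rates. This is just arithmetic on the quoted numbers (about 200 ms per frame on CPU, roughly 70-100 ms on an A100); the helper name `fps_from_ms` is made up for illustration.

```python
def fps_from_ms(frame_ms):
    """Convert a per-frame time in milliseconds to frames per second."""
    return 1000.0 / frame_ms


cpu_ms = 200.0             # previous CPU renderer, as reported above
a100_ms = (70.0, 100.0)    # A100 range (fastest, slowest), as reported above

cpu_fps = fps_from_ms(cpu_ms)
a100_fps = [fps_from_ms(ms) for ms in a100_ms]
speedups = [cpu_ms / ms for ms in a100_ms]

print(f"CPU:  {cpu_fps:.1f} fps")
print(f"A100: {a100_fps[1]:.1f}-{a100_fps[0]:.1f} fps")
print(f"Speedup over CPU: {speedups[1]:.1f}x-{speedups[0]:.1f}x")
```

The slower end of the A100 range gives exactly the 2x figure, and the faster end gives a little under 3x.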

erikfrey commented 1 year ago

Hi Joey - thanks for this tremendous effort! Unfortunately we're bound by certain restrictions that make it very difficult for us to depend on external packages unless they go through an arduous vetting process behind the scenes.

I'm curious how your approach performs on CPU?

You've probably already found that using XLA on its own limits your performance here, as XLA wasn't designed for rasterization-like operations and it does not know how to use the underlying GPU primitives to actually do this performantly.

If you're interested, there is a way to register custom XLA ops, so that you could call into CUDA's rasterization functions. TensorFlow Graphics does such a thing:

https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/rendering/opengl/rasterizer_op.cc

You could then make such an op accessible to JAX. This would be very fast, but also quite tricky to get all the plumbing to work. Just thought I'd put it out there in case you're curious to hack on such a thing.

Either way though, it would be hard for us to accept this PR due to the policies that control which external packages we can rely on. Sorry about that! But I'll leave the issue open for a while - I'd love to hear if you continue hacking on this.

JoeyTeng commented 1 year ago

Thank you so much for your suggestions! I am also thinking about a custom op, as that would dramatically improve performance.

As for current performance, rendering on high-end GPUs is actually faster than on CPU. The current implementation is not optimised for TPUs, so performance there is really bad in both execution time and memory usage (up to 98% padding...). I will benchmark and improve performance on batches of small images (e.g. 84x84, as used in RL for Atari environments) to see whether my implementation can benefit from rendering batches of environments in parallel (although we would definitely benefit when using a TPU Pod and pmap/xmap-ing all environment simulation + rendering over devices).
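
To illustrate how padding overhead of that magnitude can arise, here is a rough sketch. It assumes that an array's two minor dimensions are padded up to multiples of 8 and 128 respectively, a commonly described TPU tiling for 32-bit data. The tile sizes, the example shapes, and the helper `padding_fraction` are all assumptions for illustration, not measurements of JaxRenderer.

```python
import math


def padding_fraction(shape, tile=(8, 128)):
    """Fraction of the padded buffer that is padding, assuming the last two
    dimensions are rounded up to multiples of `tile` (an assumed layout)."""
    *lead, second_minor, minor = shape
    padded_second = math.ceil(second_minor / tile[0]) * tile[0]
    padded_minor = math.ceil(minor / tile[1]) * tile[1]
    real = math.prod(lead) * second_minor * minor
    padded = math.prod(lead) * padded_second * padded_minor
    return 1.0 - real / padded


# Channels-last RGB frame buffer: the size-3 axis is rounded up to 128.
rgb_last = padding_fraction((540, 960, 3))
# Same data with channels first: the wide image axes are the minor ones.
rgb_first = padding_fraction((3, 540, 960))
print(f"channels-last: {rgb_last:.1%}, channels-first: {rgb_first:.1%}")
```

Under these assumptions, a channels-last layout spends nearly 98% of the buffer on padding, which is in line with the figure quoted above, though the actual cause in the renderer may differ.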

sai-prasanna commented 3 months ago

How about a brax-contrib package, or something like that, where this could be added?