google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Persistent Caching of Jitted Functions on GPU for Brax Envs and Autodiff #394

Open · bebark opened this issue 9 months ago


Hi,

I am trying to use XLA's persistent compilation cache on GPU to speed up the execution of my Brax code. Judging from the JAX issue tracker this should be possible, and I have confirmed that it works without problems for most of my functions.
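For context, a minimal sketch of turning on JAX's persistent compilation cache, assuming a recent JAX release where the `jax_compilation_cache_dir` and `jax_persistent_cache_min_compile_time_secs` config options exist (option names have changed across JAX versions). The temp-dir cache path and the `scaled_sum` function are illustrative only and are not taken from the attached script, which uses ./cache.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Point JAX's persistent compilation cache at a directory. This must be set
# before the first compilation; the issue's script uses ./cache instead.
cache_dir = os.path.join(tempfile.gettempdir(), "jax_cache_demo")
jax.config.update("jax_compilation_cache_dir", cache_dir)

# Cache every compiled program, even ones that compile quickly, so that
# small test functions are also written to the cache.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0)


@jax.jit
def scaled_sum(x):
    # Illustrative jitted function; on a second run of the script its
    # compiled executable can be loaded from cache_dir instead of recompiled.
    return jnp.sum(2.0 * x)


result = scaled_sum(jnp.arange(4.0))
print(result)  # 12.0
```

Whether entries are actually written can depend on the backend and JAX version, so inspecting the cache directory after the first run is a quick sanity check.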

Unfortunately, for my use case (computing the Jacobian/Hessian of env.step with respect to obs), my code exits prematurely, without any error, when I call the persistently cached Jacobian/Hessian on subsequent runs. This happens regardless of environment and backend. I have included a minimal example that reproduces what I am seeing below. To reproduce the issue:

  1. Run minimal.py: the entire program executes, and ./cache is created and populated.
  2. Run minimal.py again: the second print (line 63) never executes, and the program quits prematurely.
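The shape of the reproduction can be sketched without Brax. In this sketch `step_wrapper` and the matrix `W` are hypothetical stand-ins for the wrapper around env.step in minimal.py; only the three transforms (`jax.jacrev`, `jax.jacfwd`, `jax.hessian`) under `jax.jit` mirror what the issue describes. Combined with the persistent-cache settings and run twice, this is the pattern that reportedly fails on the second run for `jacrev` and `hessian`.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the issue's wrapper around env.step: maps an
# observation vector to the next observation. The real script uses a Brax env.
W = jnp.array([[0.5, -0.2, 0.1],
               [0.3, 0.8, -0.4],
               [-0.6, 0.2, 0.9]])


def step_wrapper(obs):
    return jnp.tanh(W @ obs) + obs


# The three jitted transforms discussed in the issue.
jac_rev = jax.jit(jax.jacrev(step_wrapper))  # reported to fail when cached
jac_fwd = jax.jit(jax.jacfwd(step_wrapper))  # reported to work when cached
hess = jax.jit(jax.hessian(step_wrapper))    # reported to fail when cached

obs = jnp.array([0.1, -0.2, 0.3])
print("jacrev:", jac_rev(obs).shape)   # (3, 3)
print("jacfwd:", jac_fwd(obs).shape)   # (3, 3)
print("hessian:", hess(obs).shape)     # (3, 3, 3)
```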

Main takeaways:

  1. Persistent caching of jacrev (the script default) applied to my step function wrapper fails
  2. Persistently cached hessian of my step function wrapper fails as well
  3. Persistent caching of jacfwd applied to my step function wrapper works without issue
  4. All of the above work when I don't jit, or when I jit without persistent caching
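Since the un-jitted versions reportedly work, one way to check that a cached run is at least returning the right values is to compare each jitted transform against its eager result. A sketch, again with a hypothetical `step_wrapper` in place of the Brax wrapper. The forward-over-forward Hessian (`jax.jacfwd(jax.jacfwd(...))`) is included only as an assumption worth trying, given that jacfwd alone survived persistent caching; it has not been verified against the cache.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for the step-function wrapper (the real one wraps a
# Brax env.step); any smooth vector-valued function works for this check.
def step_wrapper(obs):
    return jnp.sin(obs) * jnp.sum(obs ** 2)


obs = jnp.array([0.2, -0.5, 1.0])

transforms = {
    "jacrev": jax.jacrev(step_wrapper),
    "jacfwd": jax.jacfwd(step_wrapper),
    # jax.hessian composes forward mode over reverse mode.
    "hessian": jax.hessian(step_wrapper),
    # Forward-over-forward Hessian: avoids reverse mode entirely
    # (an untested assumption as a workaround for the caching failure).
    "hessian_fwd_fwd": jax.jacfwd(jax.jacfwd(step_wrapper)),
}

# Compare each jitted transform against its eager (un-jitted) result.
for name, fn in transforms.items():
    eager = fn(obs)
    jitted = jax.jit(fn)(obs)
    print(name, jitted.shape, bool(jnp.allclose(eager, jitted, atol=1e-5)))
```

Both Hessian variants compute the same second derivatives, so they should agree up to floating-point error.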

Thanks for the help!

Attachment: minimal.txt