ami-iit / jaxsim

A differentiable physics engine and multibody dynamics library for control and robot learning.
https://jaxsim.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

Raise exceptions from jit-compiled functions #181

Closed diegoferigo closed 3 months ago

diegoferigo commented 3 months ago

This PR:

The caveat is that JAX raises an `XlaRuntimeError` to stop the execution of the jit-compiled function. The actual exception raised in the callback is printed, together with its stack trace, earlier in the output.

Although this approach does not allow handling the original exception with a `try` statement (I don't see any way to do that anyway), at least raising it stops the execution.
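A minimal sketch of the idea, assuming the exception is raised from a `jax.debug.callback` (the traceback below points to `jax/_src/debugging.py`). The helper `raise_if` and the function `jit_compiled_function` are hypothetical names for illustration, not jaxsim's actual API. The callback runs on the host with concrete values, so it can branch on the runtime value of `data` and raise a regular Python exception.

```python
import jax


def raise_if(condition, exception, msg, *args):
    """Hypothetical helper: raise `exception` from traced code when `condition` is true."""

    def _callback(condition, *args):
        # Executed on the host with concrete NumPy values, so a Python `if` is allowed.
        if condition:
            raise exception(msg.format(*args))

    # jax.debug.callback carries a side effect, so it is not dropped from the jitted program.
    jax.debug.callback(_callback, condition, *args)


@jax.jit
def jit_compiled_function(data):
    raise_if(data == 42, ValueError, "This is a test exception for {} data", data)
    return data


print(jit_compiled_function(data=40))  # runs normally

try:
    # block_until_ready ensures an error raised during execution surfaces inside the try.
    jax.block_until_ready(jit_compiled_function(data=42))
except Exception as err:
    # With the JAX version used in this PR this is an XlaRuntimeError, not the ValueError.
    print(type(err).__name__)
```

Passing `condition` as an argument of the callback, instead of branching on it in traced code, is what lets the check depend on values that are only known at run time.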


📚 Documentation preview 📚: https://jaxsim--181.org.readthedocs.build//181/

diegoferigo commented 3 months ago

On an MWE similar to the new test in this PR, the output is the following:

Output of raising exceptions in a JAX callback:

```
jax.debug_callback failed
Traceback (most recent call last):
  File "/jaxsim/lib/python3.12/site-packages/jax/_src/debugging.py", line 84, in debug_callback_impl
    callback(*args)
  File "/jaxsim/lib/python3.12/site-packages/jax/_src/debugging.py", line 246, in _flat_callback
    callback(*args, **kwargs)
  File "/home/dferigo/git/jaxsim/src/jaxsim/exceptions.py", line 46, in _raise_exception
    raise exception(msg.format(*args, **kwargs)).with_traceback(back_tb)
  File "/jaxsim/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 2486, in _wrapped_callback
    out_vals = callback(*args)
               ^^^^^^^^^^^^^^^
ValueError: This is a test exception for 42 data
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[1], line 81
     77     return data
     80 # _ = jit_compiled_function(data=40)
---> 81 _ = jit_compiled_function(data=42)

    [... skipping hidden 10 frame]

File /jaxsim/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py:1178, in ExecuteReplicated.__call__(self, *args)
   1175 if (self.ordered_effects or self.has_unordered_effects
   1176     or self.has_host_callbacks):
   1177   input_bufs = self._add_tokens_to_inputs(input_bufs)
-> 1178   results = self.xla_executable.execute_sharded(
   1179       input_bufs, with_tokens=True
   1180   )
   1181   result_token_bufs = results.disassemble_prefix_into_single_device_arrays(
   1182       len(self.ordered_effects))
   1183   sharded_runtime_token = results.consume_token()

XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: ValueError: This is a test exception for 42 data

At:
  /jaxsim/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py(2486): _wrapped_callback
  /jaxsim/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py(1178): __call__
  /jaxsim/lib/python3.12/site-packages/jax/_src/profiler.py(335): wrapper
  /jaxsim/lib/python3.12/site-packages/jax/_src/pjit.py(1488): _pjit_call_impl_python
  /jaxsim/lib/python3.12/site-packages/jax/_src/pjit.py(1534): call_impl_cache_miss
  /jaxsim/lib/python3.12/site-packages/jax/_src/pjit.py(1558): _pjit_call_impl
  /jaxsim/lib/python3.12/site-packages/jax/_src/core.py(879): process_primitive
  /jaxsim/lib/python3.12/site-packages/jax/_src/core.py(391): bind_with_trace
  /jaxsim/lib/python3.12/site-packages/jax/_src/core.py(2789): bind
  /jaxsim/lib/python3.12/site-packages/jax/_src/pjit.py(182): _python_pjit_helper
  /jaxsim/lib/python3.12/site-packages/jax/_src/pjit.py(305): cache_miss
  /jaxsim/lib/python3.12/site-packages/jax/_src/traceback_util.py(179): reraise_with_filtered_traceback
  (81):
  /jaxsim/lib/python3.12/site-packages/IPython/core/interactiveshell.py(3577): run_code
  /jaxsim/lib/python3.12/site-packages/IPython/core/interactiveshell.py(3517): run_ast_nodes
  /jaxsim/lib/python3.12/site-packages/IPython/core/interactiveshell.py(3334): run_cell_async
  /jaxsim/lib/python3.12/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
  /jaxsim/lib/python3.12/site-packages/IPython/core/interactiveshell.py(3130): _run_cell
  /jaxsim/lib/python3.12/site-packages/IPython/core/interactiveshell.py(3075): run_cell
  /jaxsim/lib/python3.12/site-packages/IPython/terminal/interactiveshell.py(910): interact
  /jaxsim/lib/python3.12/site-packages/IPython/terminal/interactiveshell.py(917): mainloop
  /jaxsim/lib/python3.12/site-packages/IPython/terminal/ipapp.py(317): start
  /jaxsim/lib/python3.12/site-packages/traitlets/config/application.py(1075): launch_instance
  /jaxsim/lib/python3.12/site-packages/IPython/__init__.py(130): start_ipython
  /jaxsim/bin/ipython(10):
```

There are two pieces of output. Before the `--------` separator is what seems to be the actual output of the callback (let's call it output 1), which raises the right exception type (`ValueError`). After the separator is the `XlaRuntimeError` exception that can be caught by the calling code (let's call it output 2).

Originally, I was trying to capture output 1 in the test, but I couldn't find any way to do that within pytest (I tried both redirecting the stdout/stderr streams to a buffer and using the `capsys` fixture). I suspect that the callback runs in a different thread or similar, which makes it impossible to capture its output (at least, I couldn't figure out a way).
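One possible workaround for inspecting output 1 in a test, which is not what this PR ends up doing: since the callback executes in the same Python process (possibly on another thread), it can store the exception object in an ordinary host-side container right before raising it. The names `captured` and `raise_if_recording` below are hypothetical, and this is only a sketch under that assumption.

```python
import jax

# Hypothetical host-side side channel, only meant for tests.
captured = []


def raise_if_recording(condition, exception, msg, *args):
    def _callback(condition, *args):
        if condition:
            exc = exception(msg.format(*args))
            captured.append(exc)  # keep the original exception object before raising it
            raise exc

    jax.debug.callback(_callback, condition, *args)


@jax.jit
def jit_compiled_function(data):
    raise_if_recording(data == 42, ValueError, "This is a test exception for {} data", data)
    return data


try:
    jax.block_until_ready(jit_compiled_function(data=42))
except Exception:
    pass  # the runtime error raised by XLA (output 2) is discarded here

# The original exception (output 1), with its real type, is still available on the host.
print(repr(captured[-1]))
```

This inspects the exception object directly instead of the text printed on stderr, so it does not depend on which thread writes the output.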

@flferretti your suggestion in https://github.com/ami-iit/jaxsim/pull/181#discussion_r1643123536 makes sense. I hadn't considered it because it can only catch the content of the `XlaRuntimeError`, which is much longer than the original exception. However, as you can see from the output above, it does contain the message of the original exception. I'll update the tests to use that, since it is good enough for testing purposes. Thanks! In any case, I wanted to post all this information here, instead of replying to the original comment, so that it is more visible to future readers.
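A sketch of such a test, assuming pytest and a `jax.debug.callback`-based helper. The helper `raise_if` and the test name are hypothetical. Since the runtime error's message embeds the original one, `pytest.raises(..., match=...)` is enough; the sketch matches on the broad `Exception` type to avoid depending on a specific import path for `XlaRuntimeError`.

```python
import jax
import pytest


def raise_if(condition, exception, msg, *args):
    # Hypothetical helper: raise on the host when `condition` is true at run time.
    def _callback(condition, *args):
        if condition:
            raise exception(msg.format(*args))

    jax.debug.callback(_callback, condition, *args)


@jax.jit
def jit_compiled_function(data):
    raise_if(data == 42, ValueError, "This is a test exception for {} data", data)
    return data


def test_exception_raised_from_jit():
    # Regular inputs do not raise.
    assert int(jit_compiled_function(data=40)) == 40

    # The error that reaches the caller embeds the message of the original ValueError,
    # so matching on the message (re.search) is enough for testing purposes.
    with pytest.raises(Exception, match="This is a test exception for 42 data"):
        jax.block_until_ready(jit_compiled_function(data=42))
```

Checking the exception type more strictly would require importing the runtime-error class from JAX/jaxlib, which ties the test to a particular version.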