On an MWE similar to the new test in this PR, I get the following output:
There are two pieces of output. Before the `--------` line there is what seems to be the actual output of the callback (let's call it output 1), which raises the right type of exception. After the line there is the `XlaRuntimeError` exception that the code can catch (let's call it output 2).
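The behavior described above can be reproduced with a minimal sketch like the following. It assumes a JAX version where an exception raised inside `jax.debug.callback` makes the jit-compiled call fail with a runtime error (an `XlaRuntimeError` in the versions discussed here). The names `_check_on_host` and `f` are hypothetical and are not taken from the PR. Output 1 is emitted by JAX itself while the callback fails. Output 2 is whatever ends up bound to `caught` and is printed after the separator.

```python
import jax
import jax.numpy as jnp


def _check_on_host(x) -> None:
    # Hypothetical host-side check: runs as plain Python outside the compiled program.
    if float(x) > 0:
        raise ValueError(f"x must be non-positive, got {float(x)}")


@jax.jit
def f(x: jax.Array) -> jax.Array:
    # The callback is staged into the jit-compiled computation and runs on the host.
    jax.debug.callback(_check_on_host, x)
    return 2 * x


caught = None
try:
    # Block so that an asynchronously dispatched failure surfaces inside the try.
    jax.block_until_ready(f(jnp.array(1.0)))
except Exception as e:  # XlaRuntimeError in the JAX versions discussed here
    caught = e

print("--------")
print(type(caught).__name__)
print(caught)
```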
Originally, I was trying to capture output 1 in the test, but I couldn't find any way to do that within pytest (I tried both redirecting the std{out,err} streams to a buffer and using the `capsys` fixture). I suspect that the callback runs in a different thread or similar, which makes it impossible to capture its output (at least, I couldn't figure out a way).
@flferretti your suggestion in https://github.com/ami-iit/jaxsim/pull/181#discussion_r1643123536 makes sense. I didn't think about it because it can only catch the content of the `XlaRuntimeError`, which is much longer than the original exception. However, as you can see from the output above, it contains the content of the original exception. I'll update the tests to use that, since it is good enough for testing purposes. Thanks! In any case, I wanted to provide all this information here rather than in the original review comment, for better visibility to future readers.
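The updated test can be sketched along these lines, assuming `pytest` and matching only on a substring of the original message. `jit_compiled_function` and `_raise_if_positive` are hypothetical names, not the PR's code. The expected exception type is deliberately left as `Exception` because the module that exports `XlaRuntimeError` differs across jax/jaxlib versions. Pin it down if your environment allows.

```python
import jax
import jax.numpy as jnp
import pytest


def _raise_if_positive(x) -> None:
    # Hypothetical host-side check: raises the "original" exception (output 1).
    if float(x) > 0:
        raise ValueError(f"data must be non-positive, got {float(x)}")


@jax.jit
def jit_compiled_function(data: jax.Array) -> jax.Array:
    jax.debug.callback(_raise_if_positive, data)
    return data + 1


def test_exception_message_is_forwarded() -> None:
    # Valid input: the callback does not raise and the result is returned.
    assert float(jit_compiled_function(jnp.array(-1.0))) == 0.0

    # Invalid input: the caller sees the runtime error (output 2), whose
    # message embeds the text of the original ValueError.
    with pytest.raises(Exception, match="data must be non-positive"):
        jax.block_until_ready(jit_compiled_function(jnp.array(1.0)))
```

`pytest.raises(..., match=...)` applies `re.search` to the string form of the raised exception. The much longer runtime error therefore still matches, as long as it embeds the original message.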
This PR:
The `condition` check (both in the callback and in the low-level `jax.lax.cond`) is necessary because JAX compiles both branches and would otherwise raise an exception while tracing. The caveat is that JAX raises an `XlaRuntimeError` to stop the execution of the jit-compiled function. The real exception raised in the callback is printed earlier in the output, together with its stack trace.

Although this method does not allow catching the original exception with a `try` statement (I don't see any way to do that, regardless), at least we can stop the execution by raising.