jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Document JAX_TRACEBACK_FILTERING #18082

Open billmark opened 1 year ago

billmark commented 1 year ago

Description

JAX_TRACEBACK_FILTERING is not well documented. The only documentation I could find is in a changelog, and it took a while to find that. All configuration options should be properly documented, ideally in one place.

The documentation should also explicitly specify that JAX_TRACEBACK_FILTERING is an environment variable (if indeed that is the case, which I believe it is).

As a bonus, it would be nice if there were a way to set configuration variables from Python, since that is sometimes much more convenient than setting an environment variable. Is there a way to set the JAX config variables directly? Alternatively, if I set os.environ['JAX_TRACEBACK_FILTERING']=blah, will that work? Do I have to do this before some initialization occurs, or can I do it at any point?
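On the environment-variable question, a minimal sketch, assuming JAX reads JAX_TRACEBACK_FILTERING once, when it is first imported, and uses it as the default for the jax_traceback_filtering option. Under that assumption, setting os.environ works only if it happens before the first import jax in the process; setting it later should have no effect, and the option would have to be changed from Python instead. Reading the value back through jax.config.jax_traceback_filtering is also an assumption about the config object's attribute access.

```python
import os

# Assumption: JAX reads JAX_TRACEBACK_FILTERING when it is first imported,
# so the variable is set before the first `import jax` in this process.
os.environ["JAX_TRACEBACK_FILTERING"] = "off"

import jax  # noqa: E402 -- deliberately imported after setting the variable

# The matching config option is the lower-case name of the variable.
print("traceback filtering:", jax.config.jax_traceback_filtering)
```

If jax has already been imported (for example earlier in a notebook session), this sketch would not change the setting; restarting the interpreter or updating the option from Python avoids that ordering problem.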


hawkinsp commented 1 year ago

We should document this better, sure. But the better question is: why did you feel the need to change it? The intention is you shouldn't need or want to do this unless you are debugging JAX itself.

jondeaton commented 6 months ago

The intention is you shouldn't need or want to do this unless you are debugging JAX itself.

I think this is only true from an idealized perspective: when learning an unfamiliar part of the JAX API, it's often helpful to see exactly where in the JAX implementation an error originates, because that helps you deduce how you're using the API incorrectly. For instance, when learning shard_map I was met with an error that was filtered down to only this:

ValueError: safe_zip() argument 2 is longer than argument 1

Not very helpful. With traceback filtering disabled, I could see that it comes from a mismatch with in_specs.

For those looking for how to turn traceback filtering off:

import jax
jax.config.update("jax_traceback_filtering", "off")
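Because jax.config.update can be called after import, filtering can also be turned off only around the code being debugged (for example, the failing shard_map call) and then restored. A minimal sketch with a hypothetical helper, traceback_filtering, which is not part of JAX's API; it assumes the current mode can be read back as jax.config.jax_traceback_filtering.

```python
import contextlib

import jax


@contextlib.contextmanager
def traceback_filtering(mode):
    """Hypothetical helper (not part of JAX): temporarily set jax_traceback_filtering."""
    previous = jax.config.jax_traceback_filtering  # remember the current mode
    jax.config.update("jax_traceback_filtering", mode)
    try:
        yield
    finally:
        # Restore the previous mode even if the body raised.
        jax.config.update("jax_traceback_filtering", previous)


# Usage: only code run inside the block sees unfiltered tracebacks.
with traceback_filtering("off"):
    print("inside:", jax.config.jax_traceback_filtering)
print("after:", jax.config.jax_traceback_filtering)
```

Restoring the previous value in a finally clause keeps the global setting from leaking into unrelated code if the debugged call raises, which is the expected case when hunting an error like the one above.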