jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

return in finally swallows exceptions #24541

Open iritkatriel opened 1 week ago

iritkatriel commented 1 week ago

Description

In https://github.com/jax-ml/jax/blob/e4eca9ec5975982f89903082532b49ec4d56da9d/jax/_src/util.py#L435 there is a return statement in a finally block, which would swallow any in-flight exception.

This means that if an unhandled exception (including a BaseException such as KeyboardInterrupt) is raised from the try body, or any exception is raised from an except: clause, it will not propagate as expected.
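A minimal standalone sketch of the behaviour described above. The function names are hypothetical and this is not the code at util.py#L435; it only shows that a return in finally discards an unhandled exception from the try body, a KeyboardInterrupt, and an exception raised inside an except clause. (Python 3.14+ also emits a SyntaxWarning for this pattern, but the code still runs.)

```python
def swallow():
    try:
        raise ValueError("lost")
    finally:
        # The return replaces the in-flight ValueError; the caller never sees it.
        return "finally wins"


def swallow_interrupt():
    try:
        raise KeyboardInterrupt  # a BaseException, not an Exception
    finally:
        return "interrupt lost"


def swallow_from_except():
    try:
        raise ValueError("original")
    except ValueError:
        # This new exception is also discarded by the return in finally.
        raise RuntimeError("raised while handling")
    finally:
        return "also lost"


print(swallow())              # finally wins
print(swallow_interrupt())    # interrupt lost
print(swallow_from_except())  # also lost
```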

If the intention is to suppress all exceptions, I would propose making this explicit by using "except BaseException:".
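A sketch of the proposed alternative, assuming suppression really is intended. The helper name explicit_suppress is hypothetical and only illustrates the pattern:

```python
def explicit_suppress(fn):
    """Hypothetical helper: call fn and deliberately discard any exception."""
    try:
        return fn()
    except BaseException:
        # Suppressing everything, including KeyboardInterrupt, is now visible
        # at the point where it happens instead of hidden in a finally block.
        return None


def boom():
    raise KeyboardInterrupt


print(explicit_suppress(boom))       # None
print(explicit_suppress(lambda: 3))  # 3
```

One difference from the return-in-finally form: an exception raised inside the except block itself would still propagate, so only the exceptions the author chose to catch are silenced.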

See also https://docs.python.org/3/tutorial/errors.html#defining-clean-up-actions.

System info (python version, jaxlib version, accelerator, etc.)

NA

jakevdp commented 1 week ago

Thanks! This may be related to the issue reported in #18246.