jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.46k stars, 2.8k forks

Global singleton (XlaDebugInfoManager) escapes the control of the C API and ends up with two copies in two shared libraries #22148

Status: Open. yliu120 opened this issue 4 months ago.

yliu120 commented 4 months ago

Description

Hi JAX team,

We identified a bug in the JAX CUDA plugin. Here is the write-up for the bug:

https://docs.google.com/document/d/1ldlD8XQ6XYX4zcSRCUIVQyAUBJQZX6v9PdE2qX2_FGw/edit?usp=sharing

To summarize,

We found by accident that XlaDebugInfoManager, an object that is supposed to be a global singleton, ends up with two copies in JAX. The reason is that the singleton is linked into both xla_extension.so and cuda_plugin.so, so different parts of the Python code reference different copies.

As a direct consequence, some metadata is missing from the profiler output, and jax.profiler does not function correctly.
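The duplication can be sketched in Python, as a rough analogy only: the names below (`make_library`, `DebugInfoManager`, `register`, `modules`) are hypothetical and are not the real XLA C++ API. Each call to `make_library` stands in for one shared library that links its own copy of the singleton, so code that registers metadata through one copy is invisible to code that reads through the other.

```python
class _DebugInfoManagerBase:
    """Hypothetical stand-in for a process-wide debug-info registry."""

    def __init__(self):
        self.modules = {}  # hypothetical: module_id -> metadata

    def register(self, module_id, metadata):
        self.modules[module_id] = metadata


def make_library(name):
    """Simulate one shared library that links its own copy of the singleton."""

    class DebugInfoManager(_DebugInfoManagerBase):
        _instance = None  # one instance slot per simulated library

        @classmethod
        def get(cls):
            # Lazily create the "singleton", but only within this library's copy.
            if cls._instance is None:
                cls._instance = cls()
            return cls._instance

    DebugInfoManager.__qualname__ = f"{name}.DebugInfoManager"
    return DebugInfoManager


# Two libraries, two independent copies of the "global" singleton.
XlaExtension = make_library("xla_extension")
CudaPlugin = make_library("cuda_plugin")

# Metadata registered through the plugin's copy...
CudaPlugin.get().register(1, "metadata for module 1")
# ...is not visible to a reader that only queries the other copy.
print(CudaPlugin.get() is XlaExtension.get())  # False
print(XlaExtension.get().modules)  # {}
```

Within a single simulated library, `get()` still returns the same object every time; the problem only appears across library boundaries, which matches the symptom described in the report.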

This is a bug report but also a feature request: we want to ensure that anything intended to be global cannot escape the control of the C API (as a safety mechanism for the future).

System info (python version, jaxlib version, accelerator, etc.)

This is a general issue with JAX plugins. I tested on the latest JAX release and on HEAD.

yliu120 commented 4 months ago

@hawkinsp I chatted with Peter offline, and I think Peter has some ideas for improving the C API to address this problem. Could you please share some thoughts here? Thanks so much.

cliveverghese commented 4 months ago

Thank you for the issue, and the documentation.

https://github.com/openxla/xla/commit/5880fa3081bcfc0bed16f69f8cd78f48b7208b00 should fix this issue. You should now be able to use the graph viewer and the memory viewer.

yliu120 commented 3 months ago

@cliveverghese I still have questions on the fix.

A functional fix is easy. From a user's perspective, we could even just disable the PJRT C API for the GPU client and use xla_extension.so. However, with your fix, it seems to me that we still have two singleton objects in two libraries.

I am confused about whether both of them collect all the info or only one of them collects everything. Is there any mechanism that guarantees a single object always collects everything?

cliveverghese commented 3 months ago

With the change, we ensure that the output is the union of the data collected by both instances, rather than picking one of them, which may contain only partial data. Before the change, two instances of XlaDebugInfoManager existed, but the profiler session collected data from only one of them (partial data).

After the change, data is collected from both instances.
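The union approach described in the last comment can be sketched in Python as follows. This is a minimal illustration, not XLA's implementation: it assumes each copy exposes a hypothetical `modules` dict keyed by module ID, and `collect_union` is a hypothetical helper. It contrasts reading a single copy (partial data) with merging every copy.

```python
class DebugInfoManager:
    """Hypothetical stand-in for one copy of the debug-info registry."""

    def __init__(self):
        self.modules = {}  # hypothetical: module_id -> metadata

    def register(self, module_id, metadata):
        self.modules[module_id] = metadata


def collect_union(managers):
    """Merge metadata from every manager copy into one dict.

    Assumes module IDs do not collide across copies; on a collision,
    the copy listed later wins because dict.update overwrites.
    """
    merged = {}
    for manager in managers:
        merged.update(manager.modules)
    return merged


# Two copies, as when the singleton is linked into two shared libraries.
xla_extension_copy = DebugInfoManager()
cuda_plugin_copy = DebugInfoManager()
xla_extension_copy.register("module_a", "metadata A")
cuda_plugin_copy.register("module_b", "metadata B")

# Reading one copy gives partial data; merging both gives everything.
print(sorted(collect_union([xla_extension_copy])))  # ['module_a']
print(sorted(collect_union([xla_extension_copy, cuda_plugin_copy])))  # ['module_a', 'module_b']
```

Note that this approach still leaves two instances in the process, which is the concern raised earlier in the thread; it only guarantees that a reader sees the combined data, provided it knows about every copy.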