Open yliu120 opened 4 months ago
@hawkinsp I chatted with Peter offline, and I understand he has some ideas for improving the C API to address this problem. Could you please share some thoughts here? Thanks so much.
Thank you for the issue, and the documentation.
https://github.com/openxla/xla/commit/5880fa3081bcfc0bed16f69f8cd78f48b7208b00 should fix this issue. You should now be able to view the graph viewer and memory viewer.
@cliveverghese I still have questions on the fix.
Basically, a functional fix is easy. From a user perspective, we could even just disable the PJRT C API for the GPU client and use `xla_extension.so`. However, with your fix, it seems to me that we still have two singleton objects in two libraries.
I am confused about whether both of them collect all the info, or whether just one of them collects everything. Is there any mechanism to guarantee that there is always exactly one object that collects everything?
With the change, we ensure that the output is the union of the data collected by both instances, rather than picking one of them, which may contain only partial data. Prior to the change, two instances of `XlaDebugInfoManager` existed, but the profiler session collected data from only one of them (partial data).
After the change, data is collected from both instances.
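As a rough illustration of the "union" behavior described above, here is a minimal Python sketch. It is not the code from the linked commit: the `DebugInfoManager` class, the `register`/`snapshot` methods, and the assumption that entries are keyed by a module id are all hypothetical stand-ins chosen for the example.

```python
class DebugInfoManager:
    """Hypothetical stand-in for one copy of the debug-info singleton."""

    def __init__(self):
        self._modules = {}

    def register(self, module_id, hlo):
        self._modules[module_id] = hlo

    def snapshot(self):
        # Return a copy so callers cannot mutate internal state.
        return dict(self._modules)


def collect_union(managers):
    """Merge the data held by every manager copy, de-duplicating by module id."""
    merged = {}
    for manager in managers:
        for module_id, hlo in manager.snapshot().items():
            # Keep the first entry seen for an id (an assumption of this sketch).
            merged.setdefault(module_id, hlo)
    return merged


# Two copies, each of which saw only part of the data.
in_extension = DebugInfoManager()
in_plugin = DebugInfoManager()
in_extension.register(1, "hlo for module 1")
in_plugin.register(2, "hlo for module 2")

partial = in_extension.snapshot()                  # pre-fix: one copy only
union = collect_union([in_extension, in_plugin])   # post-fix: both copies
```

Note that this only makes the collected output complete; both copies still exist, which is the concern raised in the question above.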
Description
Hi JAX team,
We identified a bug with the JAX CUDA plugin. Here is the write-up for the bug:
https://docs.google.com/document/d/1ldlD8XQ6XYX4zcSRCUIVQyAUBJQZX6v9PdE2qX2_FGw/edit?usp=sharing
To summarize,
We accidentally found that an object, `XlaDebugInfoManager`, which is supposed to be a global singleton instance, ends up with two copies in JAX code. The reason is that the singleton has been linked into both `xla_extension.so` and `cuda_plugin.so`, so different parts of the Python code reference different copies.
The direct consequence is that some metadata is missing from the profiler output, and `jax.profiler` does not function correctly.
This is a bug report but also a feature request, because we want to make sure that anything intended to be global does not leak from the control of the C API (a future safety mechanism).
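The duplication can be mimicked in pure Python as an analogy. This is only a sketch of the failure mode, not the actual C++ linking mechanism: executing the same source into two separate module objects stands in for statically linking the same singleton into two shared libraries. The names `fake_xla_extension`, `fake_cuda_plugin`, and `DebugInfoManager` are hypothetical.

```python
import types

# Hypothetical stand-in for the source file that defines the singleton.
SINGLETON_SOURCE = '''
class DebugInfoManager:
    _instance = None

    def __init__(self):
        self.modules = {}

    @classmethod
    def get(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance
'''


def load_copy(name):
    """Execute the same source into a fresh module object.

    Each call yields an independent class with its own `_instance`,
    much like each library carrying its own copy of a static variable.
    """
    module = types.ModuleType(name)
    exec(SINGLETON_SOURCE, module.__dict__)
    return module


extension = load_copy("fake_xla_extension")
plugin = load_copy("fake_cuda_plugin")

# The "plugin" side registers data with what it believes is the global singleton.
plugin.DebugInfoManager.get().modules["jit_f"] = "hlo proto"

# The "extension" side (where the profiler would read) sees a different object.
same_instance = extension.DebugInfoManager.get() is plugin.DebugInfoManager.get()
seen_by_extension = extension.DebugInfoManager.get().modules

print(same_instance)      # False
print(seen_by_extension)  # {}
```

Each copy behaves like a correct singleton in isolation, which is why the problem is easy to miss: nothing fails, the reader simply observes an empty or partial registry.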
System info (python version, jaxlib version, accelerator, etc.)
This is a general issue with JAX plugins. I tested on the latest JAX release and at HEAD.