ilemhadri opened this issue 2 years ago
Unfortunately I think this is a known bug, not in JAX proper but lower in the tracing/profiling stack. It's on someone's todo list, but we don't have an ETA.
@sharadmv is that correct?
In particular, this happens inside lax.scan as well.
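The same annotation pattern applies inside lax.scan: enter the scope inside the body function. Below is a minimal sketch with a hypothetical scope name, ScanStep. Because lax.scan traces its body once, the scope is recorded once in the compiled program, not once per iteration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(carry, x):
    # Hypothetical scope name; entered while lax.scan traces the body (once).
    with jax.named_scope("ScanStep"):
        carry = carry + x
    return carry, carry  # (new carry, per-step output)


@jax.jit
def cumulative(xs):
    total, partials = lax.scan(step, jnp.float32(0.0), xs)
    return total, partials


total, partials = cumulative(jnp.arange(5.0))
print(total, partials)
```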
Have you tried using jax.named_scope? I suspect the issue is that the trace annotations don't work inside a jit.
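Here is a minimal sketch of jax.named_scope inside jit. It assumes a recent JAX with the jax.stages API (lower(...).compile().as_text()). The scope name is attached to the op_name metadata of the ops created under it, so you can check for it in the compiled HLO text without any profiler UI. Unlike the repro, this function returns Z rather than 0. That is a deliberate change: when a jitted function's output does not depend on a computation, XLA may eliminate that computation as dead code.

```python
import jax
import jax.numpy as jnp


@jax.jit
def compute(M):
    Z = M @ M
    # named_scope pushes "InsideAnnotation" onto JAX's name stack at trace
    # time; ops created inside carry it in their op_name metadata.
    with jax.named_scope("InsideAnnotation"):
        for _ in range(3):
            Z = jnp.sqrt(Z @ Z)
    return Z  # keep the annotated ops live in the compiled program


x = jnp.ones((8, 8))
hlo_text = compute.lower(x).compile().as_text()
print("InsideAnnotation" in hlo_text)
```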
BTW, you can share Perfetto links by clicking Share (no need to send screenshots).
Unfortunately, that does not seem to work. Re-running the trace with
import jax
import jax.numpy as jnp

@jax.jit
def compute(M):
    Z = M @ M
    # Both annotations wrap the same region of the jitted function.
    with jax.named_scope("InsideAnnotation"):
        with jax.profiler.TraceAnnotation("InsideAnnotation"):
            for _ in range(3):
                Z = jnp.sqrt(Z @ Z)
            _, _ = jnp.linalg.eigh(Z)
    return 0

dim = 300
u = jax.random.multivariate_normal(
    key=jax.random.PRNGKey(120), mean=jnp.zeros(dim),
    cov=jnp.identity(dim), shape=(dim, dim))
with jax.profiler.trace("/home/ismlemhadri/research/ada/jax-trace", create_perfetto_link=True):
    result = jax.block_until_ready(compute(u @ u))
returns a similar trace.
Also, grepping the .json file for InsideAnnotation does not yield any results (despite using named_scope).
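The grep check can also be done programmatically. This small standard-library sketch walks a log directory and collects matching event names from Chrome-trace-format JSON files, plain or gzipped. It assumes the usual {"traceEvents": [...]} layout. The file names the profiler writes may differ between versions, so it matches any *.json or *.json.gz file rather than a specific name. The usage part builds a tiny synthetic trace with made-up event names.

```python
import gzip
import json
import os
import tempfile


def find_event_names(log_dir, needle):
    """Return sorted, de-duplicated event names containing `needle` from
    every .json or .json.gz trace file under `log_dir`."""
    matches = set()
    for root, _dirs, files in os.walk(log_dir):
        for fname in files:
            path = os.path.join(root, fname)
            if fname.endswith(".json.gz"):
                opener = gzip.open
            elif fname.endswith(".json"):
                opener = open
            else:
                continue  # skip protobuf dumps and other non-JSON files
            with opener(path, "rt", encoding="utf-8") as f:
                data = json.load(f)
            # Chrome trace format: {"traceEvents": [...]} or a bare event list.
            events = data.get("traceEvents", []) if isinstance(data, dict) else data
            for event in events:
                if isinstance(event, dict):
                    name = str(event.get("name", ""))
                    if needle in name:
                        matches.add(name)
    return sorted(matches)


# Usage on a tiny synthetic trace (made-up event names, not real profiler output).
demo_dir = tempfile.mkdtemp()
demo_path = os.path.join(demo_dir, "sample.trace.json.gz")
with gzip.open(demo_path, "wt", encoding="utf-8") as f:
    json.dump({"traceEvents": [{"name": "JaxCompiledFunction"},
                               {"name": "InsideAnnotation/dot_general"}]}, f)
demo_matches = find_event_names(demo_dir, "InsideAnnotation")
print(demo_matches)  # on a real run, pass the profile log directory instead
```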
I think two things are going on here.
a) The profiler mostly isn't enabled for CPU workloads. Try on GPU for comparison. I acknowledge you're probably interested in CPU, not GPU, but it might be a good way to rule this in or out.
b) TraceAnnotation isn't lifted into jit, i.e., the annotation in essence applies when the function is traced, not when the compiled function is run. named_scope is closer to what you want.
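Here is a minimal sketch of point b). Ordinary Python code in a jitted function, which is all that entering a TraceAnnotation context inside jit amounts to, runs only while JAX traces the function. A trace-time counter therefore stays at 1 across repeated calls with the same input shape. Wrapping the call site outside jit with jax.profiler.TraceAnnotation instead records a host-side span around each call. The names trace_count and OutsideAnnotation are illustrative.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX traces `compute`


@jax.jit
def compute(M):
    global trace_count
    # Python-side code here (including entering a TraceAnnotation context)
    # executes during tracing, not each time the compiled program runs.
    trace_count += 1
    return jnp.sqrt(M @ M)


x = jnp.ones((4, 4))
for _ in range(3):
    # An annotation around the call, outside jit, records a host-side span
    # for every invocation (visible when a profiler trace is active).
    with jax.profiler.TraceAnnotation("OutsideAnnotation"):
        y = compute(x).block_until_ready()

print(trace_count)  # 1: traced once, then the cached executable is reused
```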
OK, so this time I'm running the following snippet (using named_scope only) on a V100 GPU. Unfortunately, I don't observe meaningful changes in the profile (another screenshot here). I can see the calls to volta_dgemm_64x64 and cublas-batch-gemm, but not the annotations.
import jax
import jax.numpy as jnp

@jax.jit
def compute(M):
    Z = M @ M
    # Only named_scope this time (no TraceAnnotation).
    with jax.named_scope("InsideAnnotation"):
        for _ in range(3):
            Z = jnp.sqrt(Z @ Z)
        _, _ = jnp.linalg.eigh(Z)
    return 0

dim = 300
u = jax.random.multivariate_normal(
    key=jax.random.PRNGKey(120), mean=jnp.zeros(dim),
    cov=jnp.identity(dim), shape=(dim, dim))
with jax.profiler.trace("/home/ismlemhadri/research/ada/jax-trace", create_perfetto_link=True):
    result = jax.block_until_ready(compute(u @ u))
BTW you can share Perfetto links by clicking Share (not necessary to send screenshots)
Not on my platform: according to this link, the Share button inside Perfetto is only available to Googlers.
Any updates on this?
Bump, this seems important.
Thanks for the bump. Let me raise it in chat...
I haven't verified this myself yet, but some have reported that jax.named_scope annotations should appear in the trace when using the latest TensorBoard. That is, after running pip install --upgrade tf-nightly tbp-nightly and launching TensorBoard with something like
tensorboard --port 6006 --logdir <same logdir used to capture the profile>
the trace viewer should include the jax.named_scope calls.
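The following sketch captures a profile into a log directory that TensorBoard can then read, assuming the nightly TensorBoard/profile-plugin setup described above. The temporary directory is a placeholder for the real logdir; pass the same path to tensorboard --logdir. The files the profiler writes, and their layout, may vary between JAX versions, so the sketch only lists whatever was produced.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


@jax.jit
def compute(M):
    with jax.named_scope("InsideAnnotation"):
        Z = M @ M
        for _ in range(3):
            Z = jnp.sqrt(Z @ Z)
    return Z  # returned so the annotated ops are not optimized away


# Placeholder log directory; point `tensorboard --logdir` at the same path.
log_dir = tempfile.mkdtemp()
x = jnp.ones((64, 64))
with jax.profiler.trace(log_dir):
    jax.block_until_ready(compute(x))

# Files written by the profiler (names and layout vary by version).
captured_files = [os.path.join(root, f)
                  for root, _dirs, files in os.walk(log_dir) for f in files]
print(captured_files)

# Then, from a shell:
#   tensorboard --port 6006 --logdir <log_dir>
```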
I confirmed that this works if you pip install tf-nightly-cpu tbp-nightly.
I used versions:
tb-nightly 2.16.0a20231110
tf-nightly-cpu 2.16.0.dev20231110
The next release of TF/TB should contain the fix.
TF and TBP have had releases since @hawkinsp's comment. Can someone test and verify?
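When reporting a test result, it helps to include the exact versions used. This is a small standard-library helper (hypothetical name installed_versions) that reads installed distribution versions via importlib.metadata and returns None for packages that are not installed. The package list is only an example.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version string, or None."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


report = installed_versions(
    ["jax", "jaxlib", "tensorboard", "tensorboard-plugin-profile",
     "tb-nightly", "tbp-nightly"]
)
for pkg, ver in report.items():
    print(f"{pkg}: {ver or 'not installed'}")
```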
Description
With this snippet of code,
I get the following trace on Perfetto: https://i.imgur.com/iTXI8Ic.png. The trace only shows a JaxCompiledFunction, whereas I would expect to see InsideAnnotation as well. This is the smallest reproducible example I could think of; in my use case, I have many more annotations and observe similar issues.
What jax/jaxlib version are you using?
jax 0.3.17, jaxlib 0.3.15
Which accelerator(s) are you using?
CPU
Additional System Info
Python 3.8.13