jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

TraceAnnotation not showing inside jax.jit #12381

Open ilemhadri opened 2 years ago

ilemhadri commented 2 years ago

Description

With the following snippet:

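import jax
import jax.numpy as jnp

@jax.jit
def compute(M):
    Z = M @ M
    with jax.profiler.TraceAnnotation("InsideAnnotation"):
        for _ in range(10):
            Z = jnp.sqrt(Z @ Z)
            _, _ = jnp.linalg.eigh(Z)
    # Note: nothing computed above is returned, so XLA may dead-code-eliminate
    # this work; returning the computed values would keep it in the program.
    return 0

dim = 1000
# Note: with a length-dim mean, shape=(dim, dim) makes u a (dim, dim, dim)
# batch; shape=(dim,) would give a single (dim, dim) matrix.
u = jax.random.multivariate_normal(key=jax.random.PRNGKey(120), mean=jnp.zeros(dim), cov=jnp.identity(dim), shape=(dim, dim))
with jax.profiler.trace("/home/ismlemhadri/research/ada/jax-trace", create_perfetto_link=True):
    result = jax.block_until_ready(compute(u @ u))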

I get the following trace in Perfetto: https://i.imgur.com/iTXI8Ic.png. The trace only shows JaxCompiledFunction, whereas I would expect to see InsideAnnotation as well.

This is the smallest reproducible example I could think of. In my use case, I have many more annotations and observe similar issues.

What jax/jaxlib version are you using?

jax 0.3.17, jaxlib 0.3.15

Which accelerator(s) are you using?

CPU

Additional System Info

Python 3.8.13

mattjj commented 2 years ago

Unfortunately I think this is a known bug, not in JAX proper but lower in the tracing/profiling stack. It's on someone's todo list, but we don't have an ETA.

@sharadmv is that correct?

ilemhadri commented 2 years ago

In particular, this happens inside lax.scan as well.
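For the lax.scan case, here is a minimal sketch (the step/run functions and the data are hypothetical) of tagging the scan body with jax.named_scope, the approach suggested in the next comment. The scope is entered while the body is traced, so the name is attached to the ops of the compiled loop body rather than emitted as a runtime event:

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    # Ops traced inside this scope carry "scan_step" in their name stack,
    # which ends up in the op metadata of the compiled scan (while) body.
    with jax.named_scope("scan_step"):
        carry = carry + jnp.sin(x)
    return carry, carry

@jax.jit
def run(xs):
    # Scan over xs, returning the final carry and the per-step history.
    final, history = lax.scan(step, jnp.zeros(()), xs)
    return final, history

xs = jnp.arange(5.0)
final, history = run(xs)
```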

sharadmv commented 2 years ago

Have you tried using jax.named_scope? I suspect the issue is that the trace annotations don't work in a jit.
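A minimal sketch of the named_scope suggestion (the 4x4 input is hypothetical). Because named_scope is applied while tracing, its name is recorded in the op_name metadata of the compiled XLA ops, which can be checked in the compiled program's HLO text. Whether a particular trace viewer displays that name is a separate question:

```python
import jax
import jax.numpy as jnp

@jax.jit
def compute(M):
    Z = M @ M
    # named_scope pushes "InsideAnnotation" onto the name stack during tracing,
    # so the ops created here get it in their op_name metadata.
    with jax.named_scope("InsideAnnotation"):
        Z = jnp.sqrt(jnp.abs(Z @ Z))
    # Return Z so the scoped ops are not dead-code-eliminated.
    return Z

M = jnp.eye(4)
# Lower and compile ahead of time, then inspect the optimized HLO text.
compiled = compute.lower(M).compile()
hlo_text = compiled.as_text()
print("InsideAnnotation" in hlo_text)
```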

BTW you can share Perfetto links by clicking Share (not necessary to send screenshots)

ilemhadri commented 2 years ago

Unfortunately, that does not seem to work. Re-running the trace with

@jax.jit
def compute(M):
    Z = M @ M
    with jax.named_scope("InsideAnnotation"):
        with jax.profiler.TraceAnnotation("InsideAnnotation"):
            for _ in range(3):
                Z = jnp.sqrt(Z @ Z)
                _, _ = jnp.linalg.eigh(Z)
    return 0

dim = 300
u = jax.random.multivariate_normal(key=jax.random.PRNGKey(120), mean=jnp.zeros(dim), cov=jnp.identity(dim), shape=(dim, dim))
with jax.profiler.trace("/home/ismlemhadri/research/ada/jax-trace", create_perfetto_link=True):
    result = jax.block_until_ready(compute(u @ u))

returns a similar trace.

Also, grepping the .json file for InsideAnnotation does not yield any result (despite using named_scope).

hawkinsp commented 2 years ago

I think two things are going on here:
a) The profiler mostly isn't enabled for CPU workloads. Try on GPU for comparison. I acknowledge you're probably interested in CPU, not GPU, but it might be a good way to rule this in or out.
b) TraceAnnotation isn't lifted into jit; in essence, the annotation applies when the function is traced, not when the compiled function is run. named_scope is closer to what you want.
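Point b) can be checked without a profiler. In this minimal sketch (the counter and annotation names are hypothetical), Python code inside the jitted function, including entering the TraceAnnotation, runs only while JAX traces the function, so it executes once even though the compiled function is called three times. Wrapping the call site, outside jit, is where a TraceAnnotation marks each run:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only while JAX traces

@jax.jit
def compute(M):
    global trace_count
    # This Python block (including the TraceAnnotation context) runs once per
    # trace, not once per execution of the compiled function.
    with jax.profiler.TraceAnnotation("InsideAnnotation"):
        trace_count += 1
        Z = M @ M
    return Z

M = jnp.ones((3, 3))
for _ in range(3):
    # An annotation around the call site, outside jit, is entered on every call.
    with jax.profiler.TraceAnnotation("compute_call"):
        jax.block_until_ready(compute(M))

print(trace_count)  # 1
```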

ilemhadri commented 2 years ago

OK, so this time I'm running the following snippet (using named_scope only) on a V100 GPU. Unfortunately, I don't observe meaningful changes in the profile. Another screenshot here.

I can see the calls to volta_dgemm_64x64 and cublas-batch-gemm, but not the annotations.

@jax.jit
def compute(M):
    Z = M @ M
    with jax.named_scope("InsideAnnotation"):
        for _ in range(3):
            Z = jnp.sqrt(Z @ Z)
            _, _ = jnp.linalg.eigh(Z)
    return 0

dim = 300
u = jax.random.multivariate_normal(key=jax.random.PRNGKey(120), mean=jnp.zeros(dim), cov=jnp.identity(dim), shape=(dim, dim))
with jax.profiler.trace("/home/ismlemhadri/research/ada/jax-trace", create_perfetto_link=True):
    result = jax.block_until_ready(compute(u @ u))

Re: "BTW you can share Perfetto links by clicking Share (not necessary to send screenshots)":

Not on my platform. According to this link, the Share button inside Perfetto is only available to Googlers.

EdanToledo commented 1 year ago

Any updates on this?

hr0nix commented 1 year ago

Bump, this seems important.

mattjj commented 1 year ago

Thanks for the bump. Let me raise it in chat...

mattjj commented 11 months ago

I haven't verified this myself yet, but some have reported that jax.named_scope should appear in the trace when using the latest tensorboard. That is, after pip install --upgrade tf-nightly tbp-nightly, and running tensorboard with something like

tensorboard --port 6006 --logdir <same logdir used to capture the profile>

then the trace viewer should include jax.named_scope calls.
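A minimal capture script for this workflow, assuming a recent jax/jaxlib; the temporary log directory and the compute function are hypothetical. It compiles once outside the profiled region so the trace mostly shows execution, then writes the profile into the directory you would pass to tensorboard --logdir:

```python
import os
import tempfile

import jax
import jax.numpy as jnp

@jax.jit
def compute(M):
    # The scope name is attached to these ops at trace time.
    with jax.named_scope("InsideAnnotation"):
        return jnp.sqrt(jnp.abs(M @ M))

# Hypothetical log directory; point `tensorboard --logdir` at the same path.
logdir = tempfile.mkdtemp(prefix="jax-trace-")
M = jnp.ones((256, 256))
compute(M).block_until_ready()  # compile outside the profiled region

with jax.profiler.trace(logdir):
    jax.block_until_ready(compute(M))

# The profile is written under logdir (typically plugins/profile/<run>/*.xplane.pb).
written = [f for _, _, files in os.walk(logdir) for f in files]
print(written)
```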

hawkinsp commented 11 months ago

I confirmed that this works if you pip install tf-nightly-cpu tbp-nightly.

I used versions:

tb-nightly                   2.16.0a20231110
tf-nightly-cpu               2.16.0.dev20231110

The next release of TF/TB should contain the fix.

mattjj commented 7 months ago

TF and TBP have had releases since @hawkinsp's comment. Can someone test and verify?

j-towns commented 6 months ago

Surely jax.named_scope and jax.named_call ought to be mentioned here in the docs? I'm happy to write a PR if helpful.
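For reference, a minimal sketch of jax.named_call (the layer/model functions are hypothetical). As I understand its docstring, named_call stages the wrapped function out under the given name, which is what lets that name reach the compiled program and profiler tooling, much like entering a named_scope inside the function:

```python
import jax
import jax.numpy as jnp

def layer(x, w):
    return jnp.tanh(x @ w)

# Wrap the function so the ops it traces are tagged with "dense_layer".
named_layer = jax.named_call(layer, name="dense_layer")

@jax.jit
def model(x, w):
    # Two calls to the named function; each is staged out under the same name.
    return named_layer(named_layer(x, w), w)

x = jnp.ones((2, 3))
w = jnp.eye(3)
y = model(x, w)
```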