jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Performance Issue in `jax.numpy.einsum` of Many Operands #7413

Open daskol opened 3 years ago

daskol commented 3 years ago

I was benchmarking the effect of reusing an einsum contraction path with the opt_einsum package, which JAX uses internally (see code). The benchmark has three cases. The first uses jax.numpy.einsum directly (default). The second uses opt_einsum.contract_expression to build a reusable, JIT-compiled contraction function (reuse). The third is the same as the second, except that the reusable contraction function is not JITed (reuse-nojit). Benchmarking results follow.

                     jit      time unit  count
    name
    default         True  1519.024   ms     50
    reuse           True  1545.984   ms     50
    reuse-nojit    False   306.832   ms     50
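For readers who want to reproduce the setup, here is a minimal sketch of how the three cases might be wired up. It is not the actual bench-einsum.py: the expression, operand shapes, the `bench` helper, and passing `backend="jax"` explicitly are all assumptions made for illustration.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np
import opt_einsum as oe

# Hypothetical expression and shapes; the real benchmark's operands are not shown.
EXPR = "ab,bc,cd,de->ae"
SHAPES = [(16, 16)] * 4

rng = np.random.default_rng(0)
operands = [jnp.asarray(rng.standard_normal(s), dtype=jnp.float32) for s in SHAPES]

# Case "default": call jax.numpy.einsum directly, under jit.
default = jax.jit(lambda *ops: jnp.einsum(EXPR, *ops))

# The contraction path is computed once here, from the shapes alone.
contract = oe.contract_expression(EXPR, *SHAPES)


# Case "reuse-nojit": reuse the precomputed path; each pairwise contraction
# is dispatched to jax.numpy, but the function as a whole is not jitted.
def reuse_nojit(*ops):
    return contract(*ops, backend="jax")


# Case "reuse": the same reusable contraction, compiled as a single function.
reuse = jax.jit(reuse_nojit)


def bench(fn, count=50):
    """Return the mean wall time per call in milliseconds (assumed metric)."""
    fn(*operands).block_until_ready()  # warm-up; triggers compilation for jitted fns
    total = timeit.timeit(lambda: fn(*operands).block_until_ready(), number=count)
    return 1e3 * total / count


if __name__ == "__main__":
    for name, fn in [("default", default), ("reuse", reuse), ("reuse-nojit", reuse_nojit)]:
        print(f"{name:12s} {bench(fn):10.3f} ms")
```

With operands this small the slowdown may not show up at all; the shapes and the number of operands would need to match the real benchmark to compare against the table above.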

It seems that there is a performance issue when contracting multiple tensors with jax.numpy.einsum: both JITed variants are about five times slower than the non-JITed one.

Since GitHub does not allow attaching perf data files, the perf stat reports for opt_einsum with and without JIT are reproduced below.

$ perf stat report -i perf.reuse.data

 Performance counter stats for 'python bench-einsum.py reuse':

        229 869,71 msec task-clock:u              #    2,836 CPUs utilized
                ...............................................
    23 377 112 424      L1-dcache-loads:u         #  101,697 M/sec                      (62,50%)
    25 591 441 335      L1-dcache-load-misses:u   #  109,47% of all L1-dcache accesses  (62,52%)
    22 492 511 737      LLC-loads:u               #   97,849 M/sec                      (50,04%)
     2 726 440 928      LLC-load-misses:u         #   12,12% of all LL-cache accesses   (50,03%)

      81,045077552 seconds time elapsed

$ perf stat report -i perf.reuse-nojit.data

 Performance counter stats for 'python bench-einsum.py reuse-nojit':

        114 714,90 msec task-clock:u              #    6,504 CPUs utilized
                ...............................................
    10 610 602 864      L1-dcache-loads:u         #   92,495 M/sec                      (62,59%)
     2 467 347 851      L1-dcache-load-misses:u   #   23,25% of all L1-dcache accesses  (62,57%)
       177 916 305      LLC-loads:u               #    1,551 M/sec                      (49,95%)
       135 070 942      LLC-load-misses:u         #   75,92% of all LL-cache accesses   (49,94%)

      17,636926067 seconds time elapsed

One can see that CPU utilization is less than half as high for the JITed contraction (2.836 vs. 6.504 CPUs utilized). The number of L1 data-cache load misses is also about an order of magnitude greater for the JITed contraction (roughly 25.6 billion vs. 2.5 billion). A brief examination of the LLVM IR dumped with the XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=xla-dump" environment variable shows code bloat typical of an aggressive loop-unrolling transformation.
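The comparison can be recomputed from the two perf reports with a few lines of Python. This is only a sketch: the `parse` helper and the dictionary layout are made up for illustration, and the counter values are copied from the reports above, which use a space as the thousands separator and a comma as the decimal mark. The counters were multiplexed (see the percentages in parentheses), so the ratios are estimates, which is presumably why the L1 miss rate of the JITed run exceeds 100%.

```python
def parse(text):
    """Parse a perf number such as '229 869,71' (space thousands, comma decimal)."""
    return float(text.replace(" ", "").replace(",", "."))


# Values copied verbatim from the perf stat reports above.
reports = {
    "reuse": {
        "task_clock_ms": parse("229 869,71"),
        "l1_loads": parse("23 377 112 424"),
        "l1_misses": parse("25 591 441 335"),
        "llc_loads": parse("22 492 511 737"),
        "elapsed_s": parse("81,045077552"),
    },
    "reuse-nojit": {
        "task_clock_ms": parse("114 714,90"),
        "l1_loads": parse("10 610 602 864"),
        "l1_misses": parse("2 467 347 851"),
        "llc_loads": parse("177 916 305"),
        "elapsed_s": parse("17,636926067"),
    },
}

# Derived per-run metrics: CPUs utilized and L1 miss rate.
for run in reports.values():
    run["cpus_utilized"] = run["task_clock_ms"] / 1e3 / run["elapsed_s"]
    run["l1_miss_rate"] = run["l1_misses"] / run["l1_loads"]

jit, nojit = reports["reuse"], reports["reuse-nojit"]

# Ratios between the JITed and non-JITed runs.
elapsed_ratio = jit["elapsed_s"] / nojit["elapsed_s"]
l1_miss_ratio = jit["l1_misses"] / nojit["l1_misses"]
llc_load_ratio = jit["llc_loads"] / nojit["llc_loads"]
cpu_ratio = nojit["cpus_utilized"] / jit["cpus_utilized"]

print(f"elapsed time, jit / nojit:     {elapsed_ratio:.2f}x")
print(f"L1 load misses, jit / nojit:   {l1_miss_ratio:.2f}x")
print(f"LLC loads, jit / nojit:        {llc_load_ratio:.1f}x")
print(f"CPUs utilized, nojit / jit:    {cpu_ratio:.2f}x")
```

The recomputed CPUs-utilized and L1 miss-rate values should agree with the figures perf prints in the right-hand column of each report.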

NOTE: All experiments were carried out on CPU. Someone should check whether the issue also occurs on GPU/TPU.

zhangqiaorjc commented 3 years ago

@hawkinsp should we get an XLA/CPU person to look at the DotGeneral emitter for those cases?