I was benchmarking the effect of reusing an `einsum` contraction path with the `opt_einsum` package, which is used internally by JAX (see code). There are three cases in the benchmark. The first case uses `jax.numpy.einsum` directly (`default`). The second case uses `opt_einsum.contract_expression` to build a reusable JIT-compiled contraction function (`reuse`). The third case is the same as the second, except that the reusable contraction function is not JIT-compiled (`reuse-nojit`). A rough sketch of the three cases is shown below; the benchmarking results follow it.
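For reference, a minimal sketch of what the three cases look like. The subscripts and shapes here are placeholders; the actual contraction and invocation details are in the linked bench-einsum.py.

```python
import jax
import jax.numpy as jnp
import numpy as np
import opt_einsum

# Placeholder contraction; the real subscripts/shapes live in bench-einsum.py.
subscripts = "ab,bc,cd,de->ae"
shapes = [(64, 64)] * 4
operands = [jnp.asarray(np.random.rand(*s)) for s in shapes]

# default: plain jax.numpy.einsum wrapped in jax.jit.
default_fn = jax.jit(lambda *ops: jnp.einsum(subscripts, *ops))

# reuse / reuse-nojit: the contraction path is computed once and baked into
# a reusable expression, which is then either JIT-compiled or called eagerly.
expr = opt_einsum.contract_expression(subscripts, *shapes)
reuse_fn = jax.jit(lambda *ops: expr(*ops, backend="jax"))
reuse_nojit_fn = lambda *ops: expr(*ops, backend="jax")

for fn in (default_fn, reuse_fn, reuse_nojit_fn):
    fn(*operands).block_until_ready()
```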
| name        | jit   | time     | unit | count |
|-------------|-------|----------|------|-------|
| default     | True  | 1519.024 | ms   | 50    |
| reuse       | True  | 1545.984 | ms   | 50    |
| reuse-nojit | False | 306.832  | ms   | 50    |
It seems that there is a performance issue when contracting multiple tensors with `jax.numpy.einsum`.
Since GitHub does not allow attaching `perf` data files with the statistics, the `perf stat` reports for the `opt_einsum` cases with and without JIT are shown below.
```
$ perf stat report -i perf.reuse.data
Performance counter stats for 'python bench-einsum.py reuse':
229 869,71 msec task-clock:u # 2,836 CPUs utilized
...............................................
23 377 112 424 L1-dcache-loads:u # 101,697 M/sec (62,50%)
25 591 441 335 L1-dcache-load-misses:u # 109,47% of all L1-dcache accesses (62,52%)
22 492 511 737 LLC-loads:u # 97,849 M/sec (50,04%)
2 726 440 928 LLC-load-misses:u # 12,12% of all LL-cache accesses (50,03%)
81,045077552 seconds time elapsed
```
```
$ perf stat report -i perf.reuse-nojit.data
Performance counter stats for 'python bench-einsum.py reuse-nojit':
114 714,90 msec task-clock:u # 6,504 CPUs utilized
...............................................
10 610 602 864 L1-dcache-loads:u # 92,495 M/sec (62,59%)
2 467 347 851 L1-dcache-load-misses:u # 23,25% of all L1-dcache accesses (62,57%)
177 916 305 LLC-loads:u # 1,551 M/sec (49,95%)
135 070 942 LLC-load-misses:u # 75,92% of all LL-cache accesses (49,94%)
17,636926067 seconds time elapsed
```
One can see that CPU utilization is roughly half as high for the JIT-compiled contraction (about 2.8 vs 6.5 CPUs utilized), and its number of L1 cache misses is an order of magnitude greater. A brief examination of the LLVM IR dumped with the `XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=xla-dump"` environment variable shows code bloat typical of aggressive loop unrolling.
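For anyone who wants to look at the dumps themselves, roughly the following reproduces them (with a placeholder contraction); the flags must be set before JAX initializes the XLA backend:

```python
import os

# Must be set before jax is imported / the backend is initialized,
# otherwise XLA will not pick up the dump flags.
os.environ["XLA_FLAGS"] = "--xla_dump_hlo_as_text --xla_dump_to=xla-dump"

import jax
import jax.numpy as jnp

# Placeholder contraction; any jitted einsum will do.
f = jax.jit(lambda a, b, c: jnp.einsum("ab,bc,cd->ad", a, b, c))
x = jnp.ones((64, 64))
f(x, x, x).block_until_ready()
# The dumped HLO/IR files now sit in ./xla-dump.
```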
NOTE: All experiments were carried out on CPU. It would be worth checking whether the issue also reproduces on GPU/TPU.