clemisch opened this issue 2 months ago
Can you dump the HLO somewhere? One known problem is that small `while` loops got a lot more expensive, because we basically replaced a compiler with an interpreter; see a workaround I submitted for the CPU backend: https://github.com/google/jax/commit/15d4389247c3b680de29000ad5c1e79670d1a7e0
Linux `perf` should also show where the CPU cycles are spent.
I'll try to reproduce it on my side, but with an HLO dump I can do it a lot more easily.
You can dump HLO by setting `XLA_FLAGS=--xla_dump_to=/tmp/hlo` (if you are using more than one flag, just add a space between them, e.g., `XLA_FLAGS="--xla_dump_to=/tmp/hlo --xla_cpu_use_thunk_runtime=false"`). Could you please zip all files in the dumped folder and upload them here?
The `while` loop is one of the suspects. Another is oneDNN custom calls: they are not available in the new runtime yet, so if your code has a lot of matmuls/convolutions, you may see some slowdowns.
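For anyone following along, here is a minimal sketch of combining both flags from Python; the workload is a placeholder (not the `timing.py` from this issue). Setting `XLA_FLAGS` in the environment before JAX is imported both dumps the HLO and switches back to the old (non-thunk) CPU runtime:

```python
# Minimal sketch with a placeholder workload (not the timing.py from this issue).
import os
import time

# XLA reads XLA_FLAGS when its backend is initialized, so set it before importing jax.
os.environ["XLA_FLAGS"] = (
    "--xla_dump_to=/tmp/hlo "            # write HLO modules to /tmp/hlo
    "--xla_cpu_use_thunk_runtime=false"  # fall back to the old (non-thunk) CPU runtime
)

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Placeholder compute standing in for the real kernel.
    return jnp.tanh(x) @ x


x = jnp.ones((512, 512), dtype=jnp.float32)
step(x).block_until_ready()  # compile + warm up
t0 = time.perf_counter()
step(x).block_until_ready()
print(f"{(time.perf_counter() - t0) * 1e3:.2f} ms")
```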
Thanks for the quick feedback!
$ XLA_FLAGS="--xla_dump_to=/scratch/hlo_default" python3 timing.py --bp --size=128
produces hlo_default.zip.
$ XLA_FLAGS="--xla_dump_to=/scratch/hlo_no_thunk --xla_cpu_use_thunk_runtime=false" python3 timing.py --bp --size=128
produces hlo_no_thunk.zip.
I'll try to look into `perf` in the meantime.
Edit: removed `--fp` to unclutter the HLO.
Looks like this is because of my f64 tanh approximation commit:
https://github.com/openxla/xla/commit/ae96f6eab49fbe95c4876069a29a2d2740f1b4d5
I'll either fix it or temporarily disable it before JAX 0.4.32 re-releases.
Edited to add: This comment was meant for a different issue, not this one.
Very interesting! Thanks a lot for looking into it.
I am not (?) using `tanh` in the code, AFAIK. Is `tanh` generated by the backend, maybe when fusing some operations?
> I am not (?) using `tanh` in the code, AFAIK. Is `tanh` generated by the backend, maybe when fusing some operations?
Oops. Sorry about that! That comment was meant for a different issue (https://github.com/google/jax/issues/23590). I posted on the wrong tab.
Thank you for the HLO dump! I see one `while` loop in the main module (and no oneDNN custom calls in the no-thunk dump), so this regression may indeed be from the `while` thunk. Our team will investigate soon. :)
Ok, thanks for the clarification!
Still, I was a bit surprised: I am not using `while` loops in the code used for `timing.py`, but `scan`s. I was not aware that `scan` and `while` are lowered to the same HLO, and was (unconsciously) thinking that `scan` is more efficient. Now I see that the docs indeed state:
> `scan` is a JAX primitive and is lowered to a single `WhileOp`
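A quick way to see this lowering (a sketch with a toy cumulative-sum body, not the real projector code) is to print the lowered module of a `lax.scan` and look for the single `while` op:

```python
import jax
import jax.numpy as jnp


def cumsum_scan(xs):
    # Toy scan body: carry a running total along the leading axis.
    def body(carry, x):
        carry = carry + x
        return carry, carry

    _, ys = jax.lax.scan(body, jnp.float32(0.0), xs)
    return ys


# The lowered text contains a single `while` op implementing the scan.
print(jax.jit(cumsum_scan).lower(jnp.arange(8, dtype=jnp.float32)).as_text())
```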
jax(lib) 0.4.33 just dropped.
Performance is slightly worse than 0.4.32:
$ python3 timing.py --bp --size=128
gpu : None
prealloc : False
pmap : False
fp : False
bp : True
size : 128
dtype : 'float32'
==== BP ====
(128, 128, 128) -> (128, 128, 128) : 703 ms , 0.34 µs per voxel , 0.003 GRays/s
^^^
The old CPU backend is the same as before:
$ XLA_FLAGS="--xla_cpu_use_thunk_runtime=false" python3 timing.py --bp --size=128
gpu : None
prealloc : False
pmap : False
fp : False
bp : True
size : 128
dtype : 'float32'
==== BP ====
(128, 128, 128) -> (128, 128, 128) : 442 ms , 0.21 µs per voxel , 0.005 GRays/s
^^^
@penpornk would you prefer I post this as a jax issue?
> jax(lib) 0.4.33 just dropped.
Yes, there were issues with 0.4.32 so the wheel was pulled off PyPI and 0.4.33 was a quick re-release.
> would you prefer I post this as a jax issue?
Up to you. For the XLA:CPU team, the bug being here doesn't make a big difference. We have created an internal bug within our team and @pparuzel is looking at this issue.
For future bugs, it would be better to post on https://github.com/google/jax or https://github.com/openxla/xla, for more visibility. (And it could help JAX folks keep track of things they want to include in a new release.)
Looks like 92.7% of the time is spent in `fusion.clone`. The `WhileOp` is a mere 3.6% of the total.
That would probably narrow down to:
ROOT fusion.clone = f32[128,128,128]{2,1,0} fusion(p.1, p.2, p.3, p.4, p.5, /*index=5*/p.6, p.7, p.8), kind=kLoop, calls=fused_computation.clone, metadata={op_name="jit(get_bp)/jit(main)/while/body/add" source_file="/home/clem/git/jaxtomo/jaxtomo/projectors/cone_bp.py" source_line=71}, backend_config={"outer_dimension_partitions":["8"]}
I need to keep on digging to find the exact bottleneck.
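(As an aside, and not necessarily how the numbers above were obtained: one way to get a comparable per-op breakdown from the JAX side is a profiler trace. The sketch below uses a placeholder workload.)

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Placeholder workload, not the cone_bp kernel from this issue.
    return jnp.sum(jnp.tanh(x) * x)


x = jnp.ones((128, 128, 128), dtype=jnp.float32)
step(x).block_until_ready()  # compile outside the trace

# Writes a trace to /tmp/jax-trace; view it with TensorBoard's profile plugin
# or Perfetto to see per-op time, including fusions.
with jax.profiler.trace("/tmp/jax-trace"):
    for _ in range(10):
        step(x).block_until_ready()
```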
The performance impact seems to be somewhat inconsistent.
This is a table of the ratio `time_default / time_nothunk`, i.e. new runtime / old runtime. The cell (Intel i7-8550U, 128) is the ratio of my previous benchmark (703 ms / 442 ms).
| `--size=` | Intel i7-8550U | Intel i5-1345U |
|---|---|---|
| 128 | 1.59 (`703/442`) | 1.21 (`461/380`) |
| 256 | 1.60 (`10625/6624`) | 1.07 (`6479/6070`) |
Runtime on the old i7-8550U is generally higher, unsurprisingly. But the relative difference between the old and new runtime is much smaller on the newer i5-1345U. Could this be purely an effect of cache sizes, etc.?
We are noticing that the new runtime generates fewer SPMD instructions for this particular case, so cache size might indeed explain the difference between these architectures. However, the root cause of the slowdown is still to be found.
Currently, the suspicion is that the codegen is missing some crucial LLVM metadata, which discourages SPMD optimizations in the thunk runtime.
There is a performance regression in BP on CPU from JAX 0.4.31 to 0.4.32. The reason seems to be the new CPU backend with increased concurrency (https://github.com/google/jax/issues/23590).
Default behavior in JAX 0.4.32 vs. manually deactivated CPU concurrency: BP takes 679 ms vs. 445 ms.