clemisch / jaxtomo

Tomographic projector in JAX

Performance regression with JAX 0.4.32 on CPU #3

Open · clemisch opened this issue 2 months ago

clemisch commented 2 months ago

There is a performance regression in back projection (BP) on CPU from JAX 0.4.31 to 0.4.32. The cause seems to be the new CPU backend with increased concurrency (https://github.com/google/jax/issues/23590).

Default behavior in JAX 0.4.32:

$ python3 timing.py --bp --fp --size=128
gpu      : None
prealloc : False
pmap     : False
fp       : True
bp       : True
size     : 128
dtype    : 'float32'
==== FP ====
(128, 128, 128) -> (128, 128, 128) :  1154 ms ,  0.55 µs per pixel , 0.002 GRays/s
==== BP ====
(128, 128, 128) -> (128, 128, 128) :   679 ms ,  0.32 µs per voxel , 0.003 GRays/s
                                       ^^^

vs. with the new CPU concurrency manually deactivated:

$ XLA_FLAGS=--xla_cpu_use_thunk_runtime=false python3 timing.py --bp --fp --size=128
gpu      : None
prealloc : False
pmap     : False
fp       : True
bp       : True
size     : 128
dtype    : 'float32'
==== FP ====
(128, 128, 128) -> (128, 128, 128) :  1231 ms ,  0.59 µs per pixel , 0.002 GRays/s
==== BP ====
(128, 128, 128) -> (128, 128, 128) :   445 ms ,  0.21 µs per voxel , 0.005 GRays/s
                                       ^^^

BP takes 679 ms with the new runtime vs. 445 ms with the old one.
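
For reference, the same flag can also be set from Python, as long as it happens before JAX initializes its CPU backend. A minimal sketch (not part of timing.py):

```python
import os

# Append the flag before anything imports jax; XLA reads XLA_FLAGS
# from the environment when the backend is first initialized.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_cpu_use_thunk_runtime=false"
)

import jax            # must come after XLA_FLAGS is set
import jax.numpy as jnp

x = jnp.ones((128, 128, 128), jnp.float32)
print(jax.devices())  # the CPU backend picks up the flag on first use
```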

ezhulenev commented 2 months ago

Can you dump the HLO somewhere? One known problem is that small while loops got a lot more expensive, because we basically replaced a compiler with an interpreter; see the workaround I submitted for the CPU backend: https://github.com/google/jax/commit/15d4389247c3b680de29000ad5c1e79670d1a7e0

Linux perf should also show where the CPU cycles are spent.

I'll try to reproduce it on my side, but with an HLO dump it will be a lot easier.

penpornk commented 2 months ago

You can dump HLO by setting XLA_FLAGS=--xla_dump_to=/tmp/hlo. (If you are using more than one flag, separate them with spaces, e.g., XLA_FLAGS="--xla_dump_to=/tmp/hlo --xla_cpu_use_thunk_runtime=false".) Could you please zip all the files in the dump folder and upload them here?

The while loop is one of the suspects. Another is oneDNN custom calls: they are not available in the new runtime yet, so if your code has a lot of matmuls/convolutions, you may see some slowdowns.
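
If it helps, the HLO can also be inspected programmatically instead of via a dump directory. A rough sketch on a toy function (not the actual BP kernel); the flags above are still the way to get the full dump:

```python
import jax
import jax.numpy as jnp

def toy(a, b):
    # stand-in for a real kernel: a matmul that may become a custom call on CPU
    return a @ b + 1.0

a = jnp.ones((256, 256), jnp.float32)
b = jnp.ones((256, 256), jnp.float32)

compiled = jax.jit(toy).lower(a, b).compile()
hlo = compiled.as_text()                       # optimized HLO text
print("custom-call" in hlo, "while" in hlo)    # quick check for the two suspects
```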

clemisch commented 2 months ago

Thanks for the quick feedback!

$ XLA_FLAGS="--xla_dump_to=/scratch/hlo_default" python3 timing.py --bp --size=128

produces hlo_default.zip.

$ XLA_FLAGS="--xla_dump_to=/scratch/hlo_no_thunk --xla_cpu_use_thunk_runtime=false" python3 timing.py --bp --size=128

produces hlo_no_thunk.zip.

I'll try to look into perf in the meantime.


Edit: removed --fp to unclutter the HLO.

penpornk commented 2 months ago

Looks like this is because of my f64 tanh approximation commit: https://github.com/openxla/xla/commit/ae96f6eab49fbe95c4876069a29a2d2740f1b4d5

I'll either fix it or temporarily disable it before JAX 0.4.32 re-releases.

Edited to add: This comment was meant for a different issue, not this one.

clemisch commented 1 month ago

Very interesting! Thanks a lot for looking into it.

As far as I know, I am not using tanh in the code. Is tanh generated by the backend, maybe by fusing some operations?

penpornk commented 1 month ago

> As far as I know, I am not using tanh in the code. Is tanh generated by the backend, maybe by fusing some operations?

Oops. Sorry about that! That comment was meant for a different issue (https://github.com/google/jax/issues/23590). I posted on the wrong tab.

Thank you for the HLO dump! I see one while loop in the main module (and no oneDNN custom calls in the no-thunk dump), so this regression may indeed come from the while thunk. Our team will investigate soon. :)

clemisch commented 1 month ago

Ok, thanks for the clarification!

Still, I was a bit surprised: the code used by timing.py does not use while loops, but scans. I was not aware that scan and while_loop are lowered to the same HLO op, and I was (unconsciously) assuming that scan is more efficient. Now I see that the docs indeed state:

> scan is a JAX primitive and is lowered to a single WhileOp
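
A minimal way to see this for oneself (a toy sketch, not from timing.py): lower a trivial scan and look for the While op in the emitted text.

```python
import jax
import jax.numpy as jnp

def running_sum(xs):
    # trivial scan: carry a running sum over xs
    def body(carry, x):
        s = carry + x
        return s, s
    _, ys = jax.lax.scan(body, jnp.float32(0.0), xs)
    return ys

hlo = jax.jit(running_sum).lower(jnp.arange(8.0)).as_text()
print("while" in hlo.lower())   # True: the scan is lowered to a single While op
```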

clemisch commented 1 month ago

jax(lib) 0.4.33 just dropped.

Performance is slightly worse than with 0.4.32:

$ python3 timing.py --bp --size=128
gpu      : None
prealloc : False
pmap     : False
fp       : False
bp       : True
size     : 128
dtype    : 'float32'
==== BP ====
(128, 128, 128) -> (128, 128, 128) :   703 ms ,  0.34 µs per voxel , 0.003 GRays/s
                                       ^^^

The old CPU runtime performs the same as before:

$ XLA_FLAGS="--xla_cpu_use_thunk_runtime=false" python3 timing.py --bp --size=128
gpu      : None
prealloc : False
pmap     : False
fp       : False
bp       : True
size     : 128
dtype    : 'float32'
==== BP ====
(128, 128, 128) -> (128, 128, 128) :   442 ms ,  0.21 µs per voxel , 0.005 GRays/s
                                       ^^^
clemisch commented 1 month ago

@penpornk would you prefer I post this as a jax issue?

penpornk commented 1 month ago

> jax(lib) 0.4.33 just dropped.

Yes, there were issues with 0.4.32 so the wheel was pulled off PyPI and 0.4.33 was a quick re-release.

> would you prefer I post this as a jax issue?

Up to you. For the XLA:CPU team, whether the bug is filed here or elsewhere doesn't make a big difference. We have created an internal bug within our team and @pparuzel is looking into this issue.

For future bugs, it would be better to post on https://github.com/google/jax or https://github.com/openxla/xla, for more visibility. (And it could help JAX folks keep track of things they want to include in a new release.)

pparuzel commented 1 month ago

It looks like 92.7% of the time is spent in fusion.clone. The WhileOp accounts for a mere 3.6% of the total.

That probably narrows it down to:

  ROOT fusion.clone = f32[128,128,128]{2,1,0} fusion(p.1, p.2, p.3, p.4, p.5, /*index=5*/p.6, p.7, p.8), kind=kLoop, calls=fused_computation.clone, metadata={op_name="jit(get_bp)/jit(main)/while/body/add" source_file="/home/clem/git/jaxtomo/jaxtomo/projectors/cone_bp.py" source_line=71}, backend_config={"outer_dimension_partitions":["8"]}

I need to keep digging to find the exact bottleneck.
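
For anyone who wants to reproduce the breakdown without Linux perf, JAX's built-in profiler gives a similar picture. A sketch on a toy scan standing in for the real projector (viewing the trace needs TensorBoard with the profiler plugin, or Perfetto):

```python
import jax
import jax.numpy as jnp

def toy_bp(slices):
    # stand-in for the real back projector: accumulate slices with a scan
    def body(acc, sl):
        return acc + sl, None
    out, _ = jax.lax.scan(body, jnp.zeros_like(slices[0]), slices)
    return out

vol = jnp.ones((128, 128, 128), jnp.float32)
fn = jax.jit(toy_bp)
fn(vol).block_until_ready()              # compile and warm up outside the trace

with jax.profiler.trace("/tmp/jax-trace"):
    fn(vol).block_until_ready()
# tensorboard --logdir /tmp/jax-trace shows which fused computation dominates.
```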

clemisch commented 1 month ago

The performance impact seems to be somewhat inconsistent across machines.

This is a table of the ratio time_default / time_nothunk, i.e. new runtime over old runtime. The cell (Intel i7-8550U, size 128) is the ratio from my previous benchmark (703 ms / 442 ms).

| --size= | Intel i7-8550U    | Intel i5-1345U   |
|---------|-------------------|------------------|
| 128     | 1.59 (703/442)    | 1.21 (461/380)   |
| 256     | 1.60 (10625/6624) | 1.07 (6479/6070) |

Runtime on the old i7-8550U is generally higher, unsurprisingly. But the relative difference between the old and new runtime is much smaller on the newer i5-1345U.

Could this be an effect only of cache sizes etc.?

pparuzel commented 1 month ago

We are noticing that the new runtime generates fewer spmd instructions for this particular case, so the cache size might indeed explain the difference between these architectures. However, the root cause of the slowdown is still to be found.

Currently, the suspicion is that the codegen is missing some crucial LLVM metadata semantics, which discourages spmd optimizations in the thunk runtime.