mattjj opened this issue 4 months ago
I think this is an XLA bug because the JAX API just exposes what XLA returns in `GetHloCostAnalysis`. FWIW, `HloCostAnalysis` does return flops for HLO dot ops:
https://github.com/openxla/xla/blob/16d9159215c2afa5c9e06efdc8e87bacfec239bb/xla/service/hlo_cost_analysis.cc#L407
Looks like the `dot_general` is lowered to something else, maybe fused, or the data is lost in translation somewhere.
@mattjj Can you provide some more context on the use cases of HloCostAnalysis at the JAX level? I'm actually surprised it's exposed at that level.
I don't know the detailed history, but in general users like being able to get things like FLOP count estimates and memory requirements, especially ahead-of-time without having to run anything. Maybe they're comparing algorithms, either as a human or as part of an automated search. See e.g. google/jax#3374. I can imagine it being quite handy!
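As a concrete illustration of that ahead-of-time workflow, here is a minimal sketch using JAX's public `lower(...).compile().cost_analysis()` hook (the exact set of returned keys varies by backend and JAX version):

```python
import jax
import jax.numpy as jnp

@jax.jit
def matmul(a, b):
    return a @ b

# Abstract shapes only: nothing is executed, just lowered and compiled.
a = jax.ShapeDtypeStruct((2048, 2048), jnp.float32)
b = jax.ShapeDtypeStruct((2048, 2048), jnp.float32)
compiled = matmul.lower(a, b).compile()

# cost_analysis() surfaces XLA's HloCostAnalysis estimates; the "flops"
# entry is the number that comes back as -1.0 in this bug on XLA:CPU.
print(compiled.cost_analysis())
```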
Some folks have built jaxpr-level analyses for this (I know of at least one GDM-internal tool), though that may not be faithful to what the compiler ultimately generates (e.g. maybe the compiler decides to perform rematerialization).
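Such a jaxpr-level pass might look roughly like the following (a toy sketch, not any particular internal tool; it only handles `dot_general` with no batch dimensions and ignores every other op):

```python
import jax
import jax.numpy as jnp

def dot_flops_from_jaxpr(fn, *args):
    """Count 2*M*N*K FLOPs for each unbatched dot_general in fn's jaxpr."""
    jaxpr = jax.make_jaxpr(fn)(*args).jaxpr
    flops = 0
    for eqn in jaxpr.eqns:
        if eqn.primitive.name != "dot_general":
            continue
        (lhs_contract, rhs_contract), _ = eqn.params["dimension_numbers"]
        lhs_shape = eqn.invars[0].aval.shape
        rhs_shape = eqn.invars[1].aval.shape
        k = 1
        for d in lhs_contract:          # product of contracting dims
            k *= lhs_shape[d]
        m = 1
        for i, d in enumerate(lhs_shape):
            if i not in lhs_contract:   # remaining lhs (output) dims
                m *= d
        n = 1
        for i, d in enumerate(rhs_shape):
            if i not in rhs_contract:   # remaining rhs (output) dims
                n *= d
        flops += 2 * m * n * k          # one multiply + one add per step
    return flops

# A 4x5 @ 5x3 matmul: 2 * 4 * 3 * 5 = 120 FLOPs.
print(dot_flops_from_jaxpr(lambda a, b: a @ b,
                           jnp.ones((4, 5)), jnp.ones((5, 3))))
```

As the comment above notes, this counts the math in the traced program, which need not match what the compiler ultimately emits.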
In principle JAX+XLA should be good at providing this kind of feature, just by virtue of being staging- and compiler-oriented. And I think in principle the compiler (i.e. XLA) is the right place to compute such estimates, given that it ultimately generates the code, and that it may already need to form such estimates internally. If we want such info to include memory traffic then it seems necessary to be in the compiler. (But maybe my thinking is wrong!)
How do things look on your end?
Internally I think JAX/PjRt itself may use HloCostAnalysis for deciding whether to launch CPU computations asynchronously (see google/jax#9895), though that may be buggy on the JAX/PjRt end.
@thomasjoerg WDYT, is this a bug on XLA or a bug on JAX?
Good question, George. Without having done further debugging, I'd lean towards a JAX bug, because XLA implements the HLO cost analysis for dot ops. It may turn out to be a compiler bug once we look into it more deeply, but we'd need to start from the JAX level.
I would wager what's going on here is that XLA:CPU is lowering to something like a custom-call for which HloCostAnalysis doesn't work. Indeed we see a oneDNN custom-call in the optimized HLO printed from adapting the last line of the code in the OP to be `print(matmul.lower(A, B).compile().as_text())`:
```
HloModule jit_matmul, entry_computation_layout={(f32[2048,2048]{1,0}, f32[2048,2048]{1,0})->f32[2048,2048]{1,0}}, allow_spmd_sharding_propagation_to_output={true}

ENTRY %main.4 (Arg_0.1: f32[2048,2048], Arg_1.2: f32[2048,2048]) -> f32[2048,2048] {
  %Arg_0.1 = f32[2048,2048]{1,0} parameter(0), sharding={replicated}
  %Arg_1.2 = f32[2048,2048]{1,0} parameter(1), sharding={replicated}
  ROOT %custom-call = f32[2048,2048]{1,0} custom-call(f32[2048,2048]{1,0} %Arg_0.1, f32[2048,2048]{1,0} %Arg_1.2), custom_call_target="__onednn$matmul", metadata={op_name="jit(matmul)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/google/home/mattjj/packages/jax/tjoerg.py" source_line=5}, backend_config={"outer_dimension_partitions":[]}
}
```
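For reference, a reconstruction of that repro (the function body and array values are a guess; only the shapes, dtype, and the final `lower(...).compile().as_text()` line are from the thread):

```python
import jax
import jax.numpy as jnp

@jax.jit
def matmul(A, B):
    return A @ B

A = jnp.zeros((2048, 2048), dtype=jnp.float32)
B = jnp.zeros((2048, 2048), dtype=jnp.float32)

# After-optimizations HLO; on XLA:CPU with oneDNN enabled, the dot
# appears as a "__onednn$matmul" custom-call rather than an HLO dot.
print(matmul.lower(A, B).compile().as_text())
```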
I don't think it's a JAX plumbing issue because (1) this works on CPU for other HLO programs, and JAX plumbs the HloCostAnalysis result the same way regardless of what the HLO program is, and (2) this works for dot on other backends.
For (1), here's a different program running on XLA:CPU and producing a reasonable-seeming FLOP count:
For (2), here's a screenshot of the same code running on a corp Colab TPU:
To really verify it we could just call HloCostAnalysis directly (with no JAX plumbing) on the after-optimizations HLO of the dot computation in the OP. But I'm pretty convinced given the evidence I just listed.
So to me the real question isn't about whether it's a JAX or XLA bug, but rather it's whether you want to fix this XLA bug. I think it would be reasonable not to! After all, I'm not sure if you guys ever intended HloCostAnalysis to be a public API. Maybe it's just not a feature XLA (or at least XLA:CPU) wants to provide. That's totally reasonable, and in that case we might investigate trying to write our own cost analysis stuff to provide to users.
What do you think?
@thomasjoerg To chime in to what @mattjj has said, this would definitely be very useful for my team right now. We rely on estimated flop counts for debugging and optimization purposes. Can this XLA bug be fixed?
@olegshyshkov could you take a look?
> I would wager what's going on here is that XLA:CPU is lowering to something like a custom-call for which HloCostAnalysis doesn't work.
This is exactly where the `-1` is coming from.
After all the lowering, the original `dot` operation became a `custom-call`. Unfortunately, there is no way to reason about the performance of an arbitrary call; it can be anything. Some libraries that we call could potentially provide a cost model for their functions, but none of them do at the moment.
The information about the original instruction is also lost after all the lowering. There are only bits of information in the `metadata`:

```
metadata={op_name="jit(matmul)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/google/home/mattjj/packages/jax/tjoerg.py" source_line=5}, backend_config={"outer_dimension_partitions":[]}
```
We could, potentially, extract the matmul's dimensions and apply the same math as in `HandleDot`, but there are two big downsides:
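For a plain matmul, that math is just `2 * M * N * K` (one multiply and one add per output element per contraction step). A sketch of what "apply the same math" could look like given operand shapes recovered from the custom-call (the function name here is made up):

```python
def matmul_custom_call_flops(lhs_shape, rhs_shape):
    """2 * M * N * K for an (M, K) @ (K, N) dot: one multiply and one
    add per (output element, contraction step). This mirrors the formula
    HloCostAnalysis applies to a regular HLO dot op."""
    m, k = lhs_shape
    k2, n = rhs_shape
    assert k == k2, "contracting dimensions must match"
    return 2 * m * n * k

# The dot in the OP: f32[2048,2048] @ f32[2048,2048].
print(matmul_custom_call_flops((2048, 2048), (2048, 2048)))  # 17179869184
```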
Right now, the XLA:GPU Cost Model works reasonably well for fusions that go to the emitter (transpose, reduce, elementwise), but doesn't work for matmuls yet.
> Unfortunately, there is no way to reason about performance of a given call, it can be anything
Well, there is: we know exactly which custom calls are actually matmuls, since we lowered them to those in the first place.
Ping on this, any updated thoughts?
Nothing new on that front, but I noticed the same thing for FP8 matmuls (which have the particular aspect of potentially fusing quite a few pre/post-processing ops: https://github.com/google/jax/issues/22313).
Is there a canonical way `cost_analysis` metadata could be specified in a `custom-call` HLO op? As part of the `backend_config` dict? To me that would have 2 benefits:
- … (with `amax` and `gelu` epilogue fused);
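To make the idea concrete, here is one purely hypothetical shape it could take — a `cost_analysis` entry inside the `backend_config` JSON that a cost pass could read back. No such field exists in XLA today, and all the key names below are invented:

```python
import json

# Hypothetical backend_config the emitter could attach when it lowers a
# 2048x2048 f32 dot to the oneDNN custom-call ("cost_analysis" is made up).
backend_config = json.dumps({
    "outer_dimension_partitions": [],
    "cost_analysis": {"flops": 2 * 2048**3,
                      "bytes_accessed": 3 * 2048 * 2048 * 4},
})

def custom_call_flops(config_str):
    """Read the hypothetical entry back, defaulting to -1 (unknown),
    which is what HloCostAnalysis reports for opaque custom-calls today."""
    cfg = json.loads(config_str)
    return cfg.get("cost_analysis", {}).get("flops", -1)

print(custom_call_flops(backend_config))                        # 17179869184
print(custom_call_flops('{"outer_dimension_partitions": []}'))  # -1
```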
Originally reported as google/jax#20184 and google/jax#16008:
> I think this is an XLA bug because the JAX API just exposes what XLA returns in `GetHloCostAnalysis`. What do you think?