openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

[xla:cpu] [xla:gpu] DotGeneral ignored in GetHloCostAnalysis FLOPs #10479

Open mattjj opened 4 months ago

mattjj commented 4 months ago

Originally reported as google/jax#20184 and google/jax#16008:

Description

When inspecting the estimated flop count of a compiled function, dot_general, einsum, '@', jnp.dot, etc. all report "-1.0" flops.

import jax

@jax.jit
def matmul(A, B):
    # Contract A's last axis with B's second-to-last axis; no batch dimensions.
    return jax.lax.dot_general(A, B, (((A.ndim - 1,), (B.ndim - 2,)), ((), ())))

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key1, shape=(2048, 2048))
B = jax.random.normal(key2, shape=(2048, 2048))
# In this jaxlib version cost_analysis() returns a list of dicts, hence [0].
print(matmul.lower(A, B).compile().cost_analysis()[0]['flops'])

This prints "-1.0".

Swapping jax.lax.dot_general (and the relevant args) for jax.numpy.einsum, jax.numpy.matmul, or even '@' gives the same result. This occurs on both CPU and GPU.
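When scripting around this, it helps to treat the -1.0 as a sentinel for "no estimate" rather than as a real count. Below is a minimal sketch, assuming cost_analysis() returns either a list holding one dict (as with jaxlib 0.4.24 in the reproducer above) or a plain dict (which may be the case in other JAX versions). The helper name flops_or_none is hypothetical.

```python
from typing import Any, Optional


def flops_or_none(cost: Any) -> Optional[float]:
    """Return the 'flops' estimate, or None when XLA reports it as unknown.

    Accepts either a list of per-computation dicts or a single dict, since the
    shape of cost_analysis() output is assumed to vary across versions.
    """
    if isinstance(cost, (list, tuple)):
        if not cost:
            return None
        cost = cost[0]
    if not cost:
        return None
    flops = cost.get("flops")
    # Treat negative values (e.g. the -1.0 reported above) as "no estimate",
    # not as a real operation count.
    if flops is None or flops < 0:
        return None
    return float(flops)


# Usage with the reproducer:
#   print(flops_or_none(matmul.lower(A, B).compile().cost_analysis()))
```

Returning None instead of -1.0 keeps a missing estimate from silently entering sums or ratios when comparing programs.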

System info (python version, jaxlib version, accelerator, etc.)

jaxlib: 0.4.24
numpy:  1.26.4
python: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:50:58) [GCC 12.3.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
$ nvidia-smi
Mon Mar 11 16:48:49 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:01:00.0 Off |                  Off |
|  0%   36C    P5              35W / 480W |  23320MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A    980498      G   /usr/lib/xorg/Xorg                          183MiB |
|    0   N/A  N/A    980710      G   /usr/bin/gnome-shell                         12MiB |
|    0   N/A  N/A   1745218      C   python                                    10810MiB |
|    0   N/A  N/A   1760028      C   python                                    12220MiB |
|    0   N/A  N/A   4129416      G   ...seed-version=20240306-180105.838000       55MiB |
+---------------------------------------------------------------------------------------+

I think this is an XLA bug because the JAX API just exposes what XLA returns in GetHloCostAnalysis.

What do you think?

thomasjoerg commented 4 months ago

> I think this is an XLA bug because the JAX API just exposes what XLA returns in GetHloCostAnalysis.

FWIW, HloCostAnalysis does return flops for HLO dot ops: https://github.com/openxla/xla/blob/16d9159215c2afa5c9e06efdc8e87bacfec239bb/xla/service/hlo_cost_analysis.cc#L407
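For reference, here is what a correct answer would look like for the reproducer. This assumes the dot cost counts each multiply-add as two FLOPs, applied once per output element per contracted element. For the 2048×2048 by 2048×2048 product (M = N = K = 2048):

```latex
\text{FLOPs}(\mathrm{dot}) = 2 \cdot \underbrace{M N}_{\text{output elements}} \cdot \underbrace{K}_{\text{contracting size}}
= 2 \cdot 2048^{3} = 2^{34} = 17{,}179{,}869{,}184
```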

Looks like the dot_general is lowered to something else, maybe fused, or the data is lost in translation somewhere.

@mattjj Can you provide some more context on the use cases of HloCostAnalysis at the JAX level? I'm actually surprised it's exposed at that level.

mattjj commented 4 months ago

> @mattjj Can you provide some more context on the use cases of HloCostAnalysis at the JAX level? I'm actually surprised it's exposed at that level.

I don't know the detailed history, but in general users like being able to get things like FLOP count estimates and memory requirements, especially ahead-of-time without having to run anything. Maybe they're comparing algorithms, either as a human or as part of an automated search. See e.g. google/jax#3374. I can imagine it being quite handy!

Some folks have built jaxpr-level analyses for this (I know of at least one GDM-internal tool), though that may not be faithful to what the compiler ultimately generates (e.g. maybe the compiler decides to perform rematerialization).

In principle JAX+XLA should be good at providing this kind of feature, just by virtue of being staging- and compiler-oriented. And I think in principle the compiler (i.e. XLA) is the right place to compute such estimates, given that it ultimately generates the code, and that it may already need to form such estimates internally. If we want such info to include memory traffic then it seems necessary to be in the compiler. (But maybe my thinking is wrong!)

How do things look on your end?

mattjj commented 4 months ago

Internally I think JAX/PjRt itself may use HloCostAnalysis for deciding whether to launch CPU computations asynchronously (see google/jax#9895), though that may be buggy on the JAX/PjRt end.

cheshire commented 4 months ago

@thomasjoerg WDYT, is this a bug on XLA or a bug on JAX?

thomasjoerg commented 4 months ago

Good question, George. Without having done further debugging, I'd lean towards a JAX bug, because XLA implements the HLO Cost Analysis for dot ops. It may turn out to be a compiler bug once we look into it more deeply, but we'd need to start from the JAX level.

mattjj commented 4 months ago

I would wager that XLA:CPU is lowering the dot to something like a custom call for which HloCostAnalysis doesn't work. Indeed, we see a oneDNN custom call in the optimized HLO, printed after changing the last line of the code in the OP to print(matmul.lower(A, B).compile().as_text()):

HloModule jit_matmul, entry_computation_layout={(f32[2048,2048]{1,0}, f32[2048,2048]{1,0})->f32[2048,2048]{1,0}}, allow_spmd_sharding_propagation_to_output={true}

ENTRY %main.4 (Arg_0.1: f32[2048,2048], Arg_1.2: f32[2048,2048]) -> f32[2048,2048] {
  %Arg_0.1 = f32[2048,2048]{1,0} parameter(0), sharding={replicated}
  %Arg_1.2 = f32[2048,2048]{1,0} parameter(1), sharding={replicated}
  ROOT %custom-call = f32[2048,2048]{1,0} custom-call(f32[2048,2048]{1,0} %Arg_0.1, f32[2048,2048]{1,0} %Arg_1.2), custom_call_target="__onednn$matmul", metadata={op_name="jit(matmul)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/google/home/mattjj/packages/jax/tjoerg.py" source_line=5}, backend_config={"outer_dimension_partitions":[]}
}
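To spot this pattern in other programs, you can scan the optimized HLO text for custom calls, because those are the instructions whose cost HloCostAnalysis cannot see into. Below is a minimal regex-based sketch, assuming the text format shown above (`custom_call_target="..."`). The function name is hypothetical, and a real tool would be better off with a proper HLO parser than a regex.

```python
import re
from collections import Counter

# Matches the target string of each custom-call in HLO text, e.g.
#   custom_call_target="__onednn$matmul"
_TARGET_RE = re.compile(r'custom_call_target="([^"]+)"')


def custom_call_targets(hlo_text: str) -> Counter:
    """Count custom-call targets in optimized HLO text (e.g. compiled.as_text())."""
    return Counter(_TARGET_RE.findall(hlo_text))


hlo = '''
ROOT %custom-call = f32[2048,2048]{1,0} custom-call(f32[2048,2048]{1,0} %Arg_0.1, f32[2048,2048]{1,0} %Arg_1.2), custom_call_target="__onednn$matmul"
'''
print(custom_call_targets(hlo))  # Counter({'__onednn$matmul': 1})
```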

I don't think it's a JAX plumbing issue because (1) this works on CPU for other HLO programs, and JAX plumbs the HloCostAnalysis result the same way regardless of what the HLO program is, and (2) this works for dot on other backends.

For (1), here's a different program running on XLA:CPU and producing a reasonable-seeming FLOP count:

[screenshot: a non-dot program on XLA:CPU reporting a FLOP count]

For (2) here's a screenshot of the same code running on a corp colab TPU:

[screenshot: the same dot_general code reporting a FLOP count on a TPU]

To really verify it we could just call HloCostAnalysis directly (with no JAX plumbing) on the after-optimizations HLO of the dot computation in the OP. But I'm pretty convinced given the evidence I just listed.

So to me the real question isn't about whether it's a JAX or XLA bug, but rather it's whether you want to fix this XLA bug. I think it would be reasonable not to! After all, I'm not sure if you guys ever intended HloCostAnalysis to be a public API. Maybe it's just not a feature XLA (or at least XLA:CPU) wants to provide. That's totally reasonable, and in that case we might investigate trying to write our own cost analysis stuff to provide to users.

What do you think?

DLorell commented 4 months ago

@thomasjoerg To add to what @mattjj said, this would be very useful for my team right now. We rely on estimated FLOP counts for debugging and optimization. Can this XLA bug be fixed?

cheshire commented 4 months ago

@olegshyshkov could you take a look?

olegshyshkov commented 3 months ago

> I would wager what's going on here is that XLA:CPU is lowering to something like a customcall for which HloCostAnalysis doesn't work.

This is exactly where the -1 is coming from:

https://github.com/openxla/xla/blob/16d9159215c2afa5c9e06efdc8e87bacfec239bb/xla/service/hlo_cost_analysis.cc#L1206

After all the lowering, the original dot operation became a custom call. Unfortunately, there is no way to reason about the performance of a given custom call; it can be anything. Some of the libraries we call could potentially provide a cost model for their functions, but none of them do at the moment.

The information about the original instruction is also lost after all the lowering. Only bits of it survive in the metadata:

metadata={op_name="jit(matmul)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/google/home/mattjj/packages/jax/tjoerg.py" source_line=5}, backend_config={"outer_dimension_partitions":[]}

We could potentially extract the matmul dimensions and apply the same math as in HandleDot, but there are two big downsides:

  1. Parsing metadata is only a partial and very brittle solution.
  2. The estimate would not come anywhere close to the real performance, because it would treat all flavours of matmul the same.
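To make the trade-off concrete, here is roughly what the metadata-parsing approach would involve. This is a minimal sketch, assuming the operand shapes come from the custom call's own operand list and the dimension numbers come from the `op_name` string, both exactly as printed above. All helper names are hypothetical. The sketch also shows downside 1 in practice: any change to how JAX formats `op_name` (for example, dropping ` precision=`) breaks the regex.

```python
import ast
import math
import re
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def dot_general_flops(lhs: Shape, rhs: Shape, dimension_numbers) -> int:
    """2 FLOPs per multiply-add, over batch x lhs-free x rhs-free x contracting."""
    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
    batch = math.prod(lhs[d] for d in lhs_batch)
    contract = math.prod(lhs[d] for d in lhs_contract)
    lhs_free = math.prod(s for d, s in enumerate(lhs)
                         if d not in lhs_contract and d not in lhs_batch)
    rhs_free = math.prod(s for d, s in enumerate(rhs)
                         if d not in rhs_contract and d not in rhs_batch)
    return 2 * batch * lhs_free * rhs_free * contract


def flops_from_custom_call(line: str) -> int:
    """Recover dot FLOPs from one custom-call line plus its op_name metadata."""
    ops = re.search(r"custom-call\((.*?)\)", line)
    dims = re.search(r"dimension_numbers=(.*?) precision=", line)
    if ops is None or dims is None:
        raise ValueError("unrecognized custom-call/metadata format")
    # Operand shapes such as f32[2048,2048] -> (2048, 2048).
    shapes: Sequence[Shape] = [
        tuple(int(d) for d in s.split(",") if d)
        for s in re.findall(r"\w+\[([0-9,]*)\]", ops.group(1))
    ]
    lhs, rhs = shapes
    return dot_general_flops(lhs, rhs, ast.literal_eval(dims.group(1)))


line = (
    'ROOT %custom-call = f32[2048,2048]{1,0} custom-call(f32[2048,2048]{1,0} %Arg_0.1, '
    'f32[2048,2048]{1,0} %Arg_1.2), custom_call_target="__onednn$matmul", '
    'metadata={op_name="jit(matmul)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) '
    'precision=None preferred_element_type=None]"}'
)
print(flops_from_custom_call(line))  # 17179869184
```

Even when parsing succeeds, the count is the same for every matmul flavour, which is downside 2.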

Right now, XLA:GPU Cost Model works reasonably well for fusions that go to Emitter (transpose, reduce, elementwise), but doesn't work for matmuls, yet.

cheshire commented 2 months ago

> Unfortunately, there is no way to reason about performance of a given call, it can be anything

Well, there is, we know exactly what custom calls are actually matmuls: we've lowered them to those in the first place.
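cheshire's point suggests a different kind of fix. The compiler chose the custom-call target itself, so the cost analysis could map known matmul targets back to a dot cost instead of parsing metadata. Below is a minimal Python sketch of that dispatch; it is not XLA's actual C++ code. The `__onednn$matmul` target comes from the HLO above. `__cublas$gemm` is included as an assumed GPU GEMM target name. The `lhs_contract` argument is a hypothetical stand-in for whatever dimension information the lowering pass would record.

```python
import math
from typing import Sequence, Tuple

Shape = Tuple[int, ...]

UNKNOWN_FLOPS = -1  # the sentinel that surfaces in JAX as -1.0

# Targets the compiler itself emits for dots. "__onednn$matmul" appears in the
# HLO above; "__cublas$gemm" is an assumed GPU example.
KNOWN_MATMUL_TARGETS = {"__onednn$matmul", "__cublas$gemm"}


def _dot_like_flops(result: Shape, lhs: Shape, lhs_contract: Sequence[int]) -> int:
    # One multiply-add (2 FLOPs) per output element per contracted element.
    return 2 * math.prod(result) * math.prod(lhs[d] for d in lhs_contract)


def custom_call_flops(target: str, result: Shape, lhs: Shape,
                      lhs_contract: Sequence[int]) -> int:
    """Estimate FLOPs for a custom call, falling back to the unknown sentinel."""
    if target in KNOWN_MATMUL_TARGETS:
        return _dot_like_flops(result, lhs, lhs_contract)
    # Genuinely opaque custom calls still get no estimate.
    return UNKNOWN_FLOPS
```

The design choice here is to keep the -1 fallback for targets the compiler did not generate from a dot, so third-party custom calls are still reported as unknown rather than guessed.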

mattjj commented 2 weeks ago

Ping on this, any updated thoughts?

balancap commented 2 weeks ago

Nothing new on that front, but I noticed the same thing for FP8 matmuls (which have the particularity of potentially fusing quite a few pre/post-processing ops: https://github.com/google/jax/issues/22313).

Is there a canonical way to specify cost_analysis metadata in a custom-call HLO op, e.g. as part of the backend_config dict? To me that would have 2 benefits: