jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Human readable `as_text` for `jax.stages.Compiled` / `HloModule` #22270

Open balancap opened 4 months ago

balancap commented 4 months ago

`jax.stages.Lowered` provides a simple, readable StableHLO text output, e.g.

module @jit_matmul_fn attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<128x128xf8E4M3FN> {mhlo.layout_mode = "default"}, %arg1: tensor<128x128xf8E5M2> {mhlo.layout_mode = "default"}) -> (tensor<128x128xf8E5M2> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<128x128xf8E4M3FN>, tensor<128x128xf8E5M2>) -> tensor<128x128xf8E5M2>
    return %0 : tensor<128x128xf8E5M2>
  }
}

for a simple FP8 matmul:

import jax

def matmul_fn(a_fp8, b_fp8):
    return jax.lax.dot(a_fp8, b_fp8)
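
For context, a minimal sketch of how both text forms are obtained (dummy inputs with the shapes and dtypes taken from the module above):

import jax
import jax.numpy as jnp

# Dummy zero inputs matching the StableHLO module above.
a = jnp.zeros((128, 128), dtype=jnp.float8_e4m3fn)
b = jnp.zeros((128, 128), dtype=jnp.float8_e5m2)

lowered = jax.jit(matmul_fn).lower(a, b)
print(lowered.as_text())   # readable StableHLO, as shown above

compiled = lowered.compile()
print(compiled.as_text())  # post-optimization HLO, as shown below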

On the other hand, after optimization passes and compilation, the `HloModule` `as_text` output is much harder for a user to decipher:

'HloModule jit_matmul_fn, is_scheduled=true, entry_computation_layout={(f8e4m3fn[128,128]{1,0}, f8e5m2[128,128]{1,0})->f8e5m2[128,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="4e843a02b85ac5317998ba4c43b95c90"}\n\n%wrapped_transpose_computation (param_0: f8e5m2[128,128]) -> f8e5m2[128,128] {\n  %param_0 = f8e5m2[128,128]{1,0} parameter(0)\n  ROOT %transpose.1.1 = f8e5m2[128,128]{1,0} transpose(f8e5m2[128,128]{1,0} %param_0), dimensions={1,0}, metadata={op_name="b"}\n}\n\nENTRY %main.4 (Arg_0.1.0: f8e4m3fn[128,128], Arg_1.2.0: f8e5m2[128,128]) -> f8e5m2[128,128] {\n  %constant_1 = f32[] constant(1), metadata={op_name="jit(matmul_fn)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="" source_line=3}\n  %Arg_1.2.0 = f8e5m2[128,128]{1,0} parameter(1), metadata={op_name="b"}\n  %Arg_0.1.0 = f8e4m3fn[128,128]{1,0} parameter(0), metadata={op_name="a"}\n  %wrapped_transpose = f8e5m2[128,128]{1,0} fusion(f8e5m2[128,128]{1,0} %Arg_1.2.0), kind=kInput, calls=%wrapped_transpose_computation, metadata={op_name="b"}\n  %cublas-gemm.1.0 = (f8e5m2[128,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[128,128]{1,0} %Arg_0.1.0, f8e5m2[128,128]{1,0} %wrapped_transpose, f32[] %constant_1, f32[] %constant_1, f32[] %constant_1, /*index=5*/f32[] %constant_1), custom_call_target="__cublas$lt$matmul$f8", metadata={op_name="jit(matmul_fn)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="" source_line=3}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"],"algorithm":"ALG_UNSET"},"epilogue":"DEFAULT","damax_output":false,"selected_algorithm":"4","lhs_stride":"16384","rhs_stride":"16384","grad_x":false,"grad_y":false},"force_earliest_schedule":false}\n  ROOT %get-tuple-element.1 = f8e5m2[128,128]{1,0} get-tuple-element((f8e5m2[128,128]{1,0}, s8[33554432]{0}) %cublas-gemm.1.0), index=0, metadata={op_name="jit(matmul_fn)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="" source_line=3}\n}\n\n'

pprint helps a bit, but the result is still not great.
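
For instance (a minimal illustration, assuming the `compiled` object from the sketch above):

import pprint

# print() renders the embedded "\n" escapes of the raw repr above;
# pprint merely wraps the long string into chunks, and neither
# structures the backend_config JSON.
print(compiled.as_text())
pprint.pprint(compiled.as_text())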

Why does it matter?

In my specific case of FP8 matmuls, XLA has a collection of complex optimization passes that fuse (combinations of) pre/post scaling, bias, activation function, and amax capture into a single custom call `__cublas$lt$matmul$f8`. This fusion is critical for good performance, so an (advanced) user will most likely want to check that the call generated by the compiler is as optimal as possible (or whether an HLO pattern match failed).
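
A sketch of the kind of check one would want to automate (assuming the `compiled` object from the sketch above; the substring is the cuBLASLt custom-call target visible in the dump):

hlo_text = compiled.as_text()
# Verify XLA emitted the fused FP8 cuBLASLt call rather than falling
# back to separate scale/bias/activation ops.
assert "__cublas$lt$matmul$f8" in hlo_text, "FP8 gemm fusion pattern match failed"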

superbobry commented 4 months ago

Hi @balancap,

just to be clear what the ask is: you want

HloModule jit_matmul_fn, is_scheduled=true, entry_computation_layout={(f8e4m3fn[128,128]{1,0}, f8e5m2[128,128]{1,0})->f8e5m2[128,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="4e843a02b85ac5317998ba4c43b95c90"}

%wrapped_transpose_computation (param_0: f8e5m2[128,128]) -> f8e5m2[128,128] {
  %param_0 = f8e5m2[128,128]{1,0} parameter(0)
  ROOT %transpose.1.1 = f8e5m2[128,128]{1,0} transpose(f8e5m2[128,128]{1,0} %param_0), dimensions={1,0}, metadata={op_name="b"}
}

ENTRY %main.4 (Arg_0.1.0: f8e4m3fn[128,128], Arg_1.2.0: f8e5m2[128,128]) -> f8e5m2[128,128] {
  %constant_1 = f32[] constant(1), metadata={op_name="jit(matmul_fn)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="" source_line=3}
  %Arg_1.2.0 = f8e5m2[128,128]{1,0} parameter(1), metadata={op_name="b"}
  %Arg_0.1.0 = f8e4m3fn[128,128]{1,0} parameter(0), metadata={op_name="a"}
  %wrapped_transpose = f8e5m2[128,128]{1,0} fusion(f8e5m2[128,128]{1,0} %Arg_1.2.0), kind=kInput, calls=%wrapped_transpose_computation, metadata={op_name="b"}
  %cublas-gemm.1.0 = (f8e5m2[128,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[128,128]{1,0} %Arg_0.1.0, f8e5m2[128,128]{1,0} %wrapped_transpose, f32[] %constant_1, f32[] %constant_1, f32[] %constant_1, /*index=5*/f32[] %constant_1), custom_call_target="__cublas$lt$matmul$f8", metadata={op_name="jit(matmul_fn)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="" source_line=3}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"],"algorithm":"ALG_UNSET"},"epilogue":"DEFAULT","damax_output":false,"selected_algorithm":"4","lhs_stride":"16384","rhs_stride":"16384","grad_x":false,"grad_y":false},"force_earliest_schedule":false}
  ROOT %get-tuple-element.1 = f8e5m2[128,128]{1,0} get-tuple-element((f8e5m2[128,128]{1,0}, s8[33554432]{0}) %cublas-gemm.1.0), index=0, metadata={op_name="jit(matmul_fn)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="" source_line=3}
}

to be easier to read. Does that sound right?

balancap commented 4 months ago

Indeed, or give the user a way to get a more readable form via more detailed bindings of an HLO module, so that one can filter out metadata and extract an op's config dictionary. I had to do some hacky text parsing to be able to directly print:

ENTRY %main.24 (Arg_0.1.0: f8e4m3fn[128,128], Arg_1.2.0: f8e4m3fn[128,128], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> (f8e4m3fn[128,128,4], f32[]) {
  %constant_2_0 = f32[] constant(1)
  %Arg_4.5.0 = f32[] parameter(4)
  %Arg_3.4.0 = f32[] parameter(3)
  %Arg_2.3.0 = f32[] parameter(2)
  %Arg_1.2.0 = f8e4m3fn[128,128]{1,0} parameter(1)
  %Arg_0.1.0 = f8e4m3fn[128,128]{1,0} parameter(0)
  %wrapped_transpose = f8e4m3fn[128,128]{1,0} fusion(f8e4m3fn[128,128]{1,0} %Arg_1.2.0), kind=kInput, calls=%wrapped_transpose_computation
  %cublas-gemm.1.0 = (f32[128,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[128,128]{1,0} %Arg_0.1.0, f8e4m3fn[128,128]{1,0} %wrapped_transpose, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_2_0, /*index=5*/f32[] %constant_2_0), custom_call_target="__cublas$lt$matmul$f8"
backend_cfg:  {
    "operation_queue_id": "0",
    "wait_on_operation_queues": [],
    "gemm_backend_config": {
        "alpha_real": 1,
        "alpha_imag": 0,
        "beta": 0,
        "dot_dimension_numbers": {
            "lhs_contracting_dimensions": [
                "1"
            ],
            "rhs_contracting_dimensions": [
                "1"
            ],
            "lhs_batch_dimensions": [],
            "rhs_batch_dimensions": []
        },
        "precision_config": {
            "operand_precision": [
                "DEFAULT",
                "DEFAULT"
            ],
            "algorithm": "ALG_UNSET"
        },
        "epilogue": "DEFAULT",
        "damax_output": false,
        "selected_algorithm": "2",
        "lhs_stride": "16384",
        "rhs_stride": "16384",
        "grad_x": false,
        "grad_y": false
    },
    "force_earliest_schedule": false
}
  %get-tuple-element.1 = f32[128,128]{1,0} get-tuple-element((f32[128,128]{1,0}, s8[33554432]{0}) %cublas-gemm.1.0), index=0
  %input_reduce_fusion = f32[] fusion(f32[128,128]{1,0} %get-tuple-element.1), kind=kInput, calls=%fused_reduce
  %loop_bitcast_convert_fusion = f8e4m3fn[128,128,4]{2,1,0} fusion(f32[128,128]{1,0} %get-tuple-element.1, f32[] %Arg_4.5.0), kind=kLoop, calls=%fused_bitcast_convert
  ROOT %tuple.23.0 = (f8e4m3fn[128,128,4]{2,1,0}, f32[]) tuple(f8e4m3fn[128,128,4]{2,1,0} %loop_bitcast_convert_fusion, f32[] %input_reduce_fusion)
}

in #22313
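
Something along these lines (a rough sketch of that kind of text parsing, not the actual code from #22313; it assumes `backend_config` is the last attribute on its instruction line, as in the dump above):

import json
import re

# Pull the backend_config JSON out of each HLO instruction line and
# pretty-print it on its own, dropping the surrounding metadata.
for line in compiled.as_text().splitlines():
    m = re.search(r'backend_config=(\{.*\})\s*$', line)
    if m:
        print("backend_cfg: ", json.dumps(json.loads(m.group(1)), indent=4))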