balancap opened this issue 4 months ago
Hi @balancap,
just to be clear what the ask is: you want
HloModule jit_matmul_fn, is_scheduled=true, entry_computation_layout={(f8e4m3fn[128,128]{1,0}, f8e5m2[128,128]{1,0})->f8e5m2[128,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="4e843a02b85ac5317998ba4c43b95c90"}
%wrapped_transpose_computation (param_0: f8e5m2[128,128]) -> f8e5m2[128,128] {
%param_0 = f8e5m2[128,128]{1,0} parameter(0)
ROOT %transpose.1.1 = f8e5m2[128,128]{1,0} transpose(f8e5m2[128,128]{1,0} %param_0), dimensions={1,0}, metadata={op_name="b"}
}
ENTRY %main.4 (Arg_0.1.0: f8e4m3fn[128,128], Arg_1.2.0: f8e5m2[128,128]) -> f8e5m2[128,128] {
%constant_1 = f32[] constant(1), metadata={op_name="jit(matmul_fn)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="" source_line=3}
%Arg_1.2.0 = f8e5m2[128,128]{1,0} parameter(1), metadata={op_name="b"}
%Arg_0.1.0 = f8e4m3fn[128,128]{1,0} parameter(0), metadata={op_name="a"}
%wrapped_transpose = f8e5m2[128,128]{1,0} fusion(f8e5m2[128,128]{1,0} %Arg_1.2.0), kind=kInput, calls=%wrapped_transpose_computation, metadata={op_name="b"}
%cublas-gemm.1.0 = (f8e5m2[128,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[128,128]{1,0} %Arg_0.1.0, f8e5m2[128,128]{1,0} %wrapped_transpose, f32[] %constant_1, f32[] %constant_1, f32[] %constant_1, /*index=5*/f32[] %constant_1), custom_call_target="__cublas$lt$matmul$f8", metadata={op_name="jit(matmul_fn)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="" source_line=3}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"],"algorithm":"ALG_UNSET"},"epilogue":"DEFAULT","damax_output":false,"selected_algorithm":"4","lhs_stride":"16384","rhs_stride":"16384","grad_x":false,"grad_y":false},"force_earliest_schedule":false}
ROOT %get-tuple-element.1 = f8e5m2[128,128]{1,0} get-tuple-element((f8e5m2[128,128]{1,0}, s8[33554432]{0}) %cublas-gemm.1.0), index=0, metadata={op_name="jit(matmul_fn)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="" source_line=3}
}
to be easier to read. Does that sound right?
Indeed. Or, alternatively, give the user a way to get a more readable form through more detailed bindings of an HLO module, so that one can filter out metadata and extract an op's config dictionary. I had to do some hacky text parsing to be able to directly print:
ENTRY %main.24 (Arg_0.1.0: f8e4m3fn[128,128], Arg_1.2.0: f8e4m3fn[128,128], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> (f8e4m3fn[128,128,4], f32[]) {
%constant_2_0 = f32[] constant(1)
%Arg_4.5.0 = f32[] parameter(4)
%Arg_3.4.0 = f32[] parameter(3)
%Arg_2.3.0 = f32[] parameter(2)
%Arg_1.2.0 = f8e4m3fn[128,128]{1,0} parameter(1)
%Arg_0.1.0 = f8e4m3fn[128,128]{1,0} parameter(0)
%wrapped_transpose = f8e4m3fn[128,128]{1,0} fusion(f8e4m3fn[128,128]{1,0} %Arg_1.2.0), kind=kInput, calls=%wrapped_transpose_computation
%cublas-gemm.1.0 = (f32[128,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[128,128]{1,0} %Arg_0.1.0, f8e4m3fn[128,128]{1,0} %wrapped_transpose, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_2_0, /*index=5*/f32[] %constant_2_0), custom_call_target="__cublas$lt$matmul$f8"
backend_cfg: {
  "operation_queue_id": "0",
  "wait_on_operation_queues": [],
  "gemm_backend_config": {
    "alpha_real": 1,
    "alpha_imag": 0,
    "beta": 0,
    "dot_dimension_numbers": {
      "lhs_contracting_dimensions": [
        "1"
      ],
      "rhs_contracting_dimensions": [
        "1"
      ],
      "lhs_batch_dimensions": [],
      "rhs_batch_dimensions": []
    },
    "precision_config": {
      "operand_precision": [
        "DEFAULT",
        "DEFAULT"
      ],
      "algorithm": "ALG_UNSET"
    },
    "epilogue": "DEFAULT",
    "damax_output": false,
    "selected_algorithm": "2",
    "lhs_stride": "16384",
    "rhs_stride": "16384",
    "grad_x": false,
    "grad_y": false
  },
  "force_earliest_schedule": false
}
%get-tuple-element.1 = f32[128,128]{1,0} get-tuple-element((f32[128,128]{1,0}, s8[33554432]{0}) %cublas-gemm.1.0), index=0
%input_reduce_fusion = f32[] fusion(f32[128,128]{1,0} %get-tuple-element.1), kind=kInput, calls=%fused_reduce
%loop_bitcast_convert_fusion = f8e4m3fn[128,128,4]{2,1,0} fusion(f32[128,128]{1,0} %get-tuple-element.1, f32[] %Arg_4.5.0), kind=kLoop, calls=%fused_bitcast_convert
ROOT %tuple.23.0 = (f8e4m3fn[128,128,4]{2,1,0}, f32[]) tuple(f8e4m3fn[128,128,4]{2,1,0} %loop_bitcast_convert_fusion, f32[] %input_reduce_fusion)
}
in #22313; a minimal sketch of that parsing follows.
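The sketch below shows the idea, not the actual code from #22313: a hypothetical pretty_print_hlo helper that strips metadata clauses and pretty-prints inline backend_config JSON. The regexes assume metadata={...} contains no nested braces and that backend_config={...} is the last clause on its line, both of which hold for the dumps in this thread (backend_config is not guaranteed to be JSON for every custom call, hence "hacky").

```python
import json
import re


def pretty_print_hlo(hlo_text: str) -> str:
    """Hypothetical helper: drop metadata clauses and pretty-print the
    inline backend_config JSON of each HLO instruction."""
    out = []
    for line in hlo_text.splitlines():
        # Drop the noisy metadata={...} clause (assumes no nested braces).
        line = re.sub(r",?\s*metadata=\{[^}]*\}", "", line)
        # Pull out the trailing backend_config JSON, if any.
        m = re.search(r",?\s*backend_config=(\{.*\})\s*$", line)
        if m:
            cfg = json.loads(m.group(1))
            out.append(line[: m.start()])
            out.append("backend_cfg: " + json.dumps(cfg, indent=2))
        else:
            out.append(line)
    return "\n".join(out)
```

Applied to the text from HloModule as_text (or jax.stages.Compiled.as_text()), this produces output in the shape of the dump above.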
jax.stages.Lowered provides a simple, readable StableHLO text output, e.g. for a simple FP8 matmul. On the other hand, after optimization passes and compilation, the HloModule as_text output is much harder for a user to decipher: pprint helps a bit, but it is still not great.
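For illustration, a minimal sketch of how to obtain both text forms (assuming a recent JAX with FP8 dtypes and a GPU backend that supports them; matmul_fn and the zero-filled inputs are illustrative only):

```python
import jax
import jax.numpy as jnp


def matmul_fn(a, b):
    return jnp.matmul(a, b)


# Illustrative FP8 inputs; real FP8 code would also carry scaling factors.
a = jnp.zeros((128, 128), dtype=jnp.float8_e4m3fn)
b = jnp.zeros((128, 128), dtype=jnp.float8_e4m3fn)

lowered = jax.jit(matmul_fn).lower(a, b)  # jax.stages.Lowered
print(lowered.as_text())    # simple, stable, readable StableHLO

compiled = lowered.compile()  # jax.stages.Compiled
print(compiled.as_text())   # post-optimization HLO: dense and hard to read
```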
Why does it matter? In my particular case of FP8 matmuls, XLA has a collection of complex optimization passes that fuse (combinations of) pre/post scaling + bias + activation function + amax capture into the single custom call __cublas$lt$matmul$f8. This fusion is critical to obtaining good performance, so an (advanced) user will most likely want to check that the call generated by the compiler is as optimal as possible (or whether an HLO pattern match failed).
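As a concrete example of such a check, a short sketch (with compiled as in the snippet above; the assertion message is mine):

```python
# Hypothetical sanity check: did XLA's FP8 gemm rewriter actually produce
# the fused cublasLt custom call in the optimized module?
optimized_hlo = compiled.as_text()
assert "__cublas$lt$matmul$f8" in optimized_hlo, "FP8 gemm fusion did not fire"
```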