For transformer models with small to medium-sized GEMMs, the benefit of using fp8 cublasLt GEMMs can be outweighed by the extra memory traffic introduced by the quantization and amax reduction kernels. With fp8, each matmul incurs six memory reads of its operands (three each for the lhs and rhs); without fp8, there are only two reads per matmul. The quantization and amax reduction kernels are memory-bound, so fusing them would be advantageous.
The reproducer is a simple input -> dropout -> dense -> gelu -> dense pattern, which is common in the MLP block of transformers (a minimal sketch is included below).
From the dumped HLO, it's evident that the quantization (fusion.103) and amax reduction (fusion.35) paths are not fused for the weights, and neither are those for the activations (e.g. fusion.92 and fusion.65). @kaixih @reedwm @nluehr @philipphack
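For reference, here is a minimal sketch of what such a reproducer might look like: a plain Keras MLP block (dropout -> dense -> gelu -> dense) compiled with XLA. The layer sizes and batch shape are placeholders, and the actual test.py presumably drives fp8 through an fp8-aware dense implementation rather than the stock tf.keras.layers.Dense used here.

```python
# Minimal sketch of the reproducer pattern: input -> dropout -> dense -> gelu -> dense.
# Assumption: the real test.py likely swaps in fp8 dense layers that produce the
# quantization and amax-reduction kernels discussed above; this sketch only shows
# the surrounding MLP structure and the XLA compilation entry point.
import tensorflow as tf

hidden, ffn = 1024, 4096  # placeholder sizes

dropout = tf.keras.layers.Dropout(0.1)
dense1 = tf.keras.layers.Dense(ffn)
dense2 = tf.keras.layers.Dense(hidden)

@tf.function(jit_compile=True)  # force XLA so the HLO can be dumped
def mlp_block(x):
    x = dropout(x, training=True)
    x = dense1(x)
    x = tf.nn.gelu(x)
    x = dense2(x)
    return x

x = tf.random.normal([8, 128, hidden])
print(mlp_block(x).shape)
```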
To generate the HLO dump:
TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" XLA_FLAGS="--xla_gpu_graph_level=0 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_reduction_epilogue_fusion=false --xla_dump_hlo_as_html --xla_dump_to=/tmp/generated --xla_dump_hlo_pass_re=.*" python test.py