openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0
2.62k stars 409 forks source link

Fusion of Fp8 quantization and amax reduction kernels #7432

Open wenscarl opened 10 months ago

wenscarl commented 10 months ago

For transformer models with small to medium-sized GEMMs, the advantage of using FP8 cuBLASLt GEMMs may be overshadowed by the overhead of the extra memory traffic in the quantization and amax-reduction kernels. Using an FP8 GEMM raises the number of memory reads per matmul to six (three each for the LHS and the RHS), compared with only two reads per matmul without FP8. Because the quantization and amax-reduction kernels are memory-bound, it is advantageous to fuse them.

The reproducer is a simple input -> LayerNorm -> dense -> GELU -> dense pattern, which is common in the MLP of transformers. The dumped HLO shows that the quantization (fusion.103) and amax-reduction (fusion.35) paths are not fused, and this holds not only for the weights but also for the activations (e.g. fusion.92 and fusion.65). @kaixih @reedwm @nluehr @philipphack

To generate the HLO dump, run: TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" XLA_FLAGS="--xla_gpu_graph_level=0 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_reduction_epilogue_fusion=false --xla_dump_hlo_as_html --xla_dump_to=/tmp/generated --xla_dump_hlo_pass_re=.*" python test.py

### test.py
import time
from functools import partial
from typing import Any, Callable, Iterable, Optional, Sequence, Union

import jax
import jax.numpy as jnp
import optax

from flax import linen as nn
from flax.training.train_state import TrainState
from jax.experimental.pjit import pjit

# Aliases used only to annotate the code below; they add no runtime checks.
PRNGKey = jnp.ndarray
Array = jnp.ndarray
DType = jnp.dtype
Shape = Iterable[int]
# Initializer signature: (rng key, shape, dtype) -> array.
Initializer = Callable[[PRNGKey, Shape, DType], Array]

# Mixed-precision dtype. The JAX examples use bfloat16, whereas the
# TensorFlow examples use float16.
dtype = jnp.bfloat16
# Dense layer class, plus the extra keyword arguments that select Flax's
# FP8 dot_general op for it.
DenseLayer = nn.DenseGeneral
dg_args = dict(dot_general_cls=nn.Fp8DotGeneralOp)

class BasicMLP(nn.Module):
  """Feed-forward network in a Transformer layer.

  Computes ``wo(gelu(wi(inputs)))``. ``wi`` has ``ffn_hidden_size`` output
  features and ``wo`` has ``hidden_size`` output features. Both are
  ``DenseLayer`` instances built with ``dg_args`` (the FP8 dot_general op).

  Attributes:
    hidden_size: number of output features of the block.
    ffn_hidden_size: number of intermediate features.
    kernel_init: initializer used for both dense kernels.
    dtype: dtype passed to both dense layers.
  """
  hidden_size: int
  ffn_hidden_size: int
  kernel_init: Initializer = nn.initializers.variance_scaling(
      1.0, 'fan_in', 'truncated_normal')
  dtype: Any = jnp.bfloat16

  @nn.compact
  def __call__(self, inputs):
    # Use the module's own ``dtype`` field rather than the module-level
    # ``dtype`` global, so that BasicMLP(dtype=...) actually takes effect.
    x = DenseLayer(self.ffn_hidden_size, kernel_init=self.kernel_init,
                   dtype=self.dtype, name='wi', **dg_args)(inputs)
    x = nn.gelu(x)
    return DenseLayer(self.hidden_size, kernel_init=self.kernel_init,
                      dtype=self.dtype, name='wo', **dg_args)(x)

class BasicTransformer(nn.Module):
  """LayerNorm followed by ``BasicMLP``.

  Attributes:
    hidden_size: feature size of the MLP output.
    ffn_hidden_size: intermediate feature size used inside the MLP.
  """
  hidden_size: int
  ffn_hidden_size: int

  def setup(self):
    self.ln = nn.LayerNorm()
    self.mlp = BasicMLP(
        ffn_hidden_size=self.ffn_hidden_size,
        hidden_size=self.hidden_size,
    )

  def __call__(self, inputs, attention_mask=None):
    # ``attention_mask`` is accepted for API compatibility but never used:
    # this module contains no attention.
    normalized = self.ln(inputs)
    return self.mlp(normalized)

# Layer configuration.
batch_size = 4
sequence_length = 2_048
hidden_size = 4_096
ffn_hidden_size = 4 * hidden_size  # = 16_384

def run_benchmark(timing_iters=20, warmup_iters=20):
  """Time one training step of ``BasicTransformer`` and print the mean.

  The function:
  - builds random inputs and labels;
  - initializes the model and an Adam ``TrainState``;
  - compiles the train step with ``pjit``;
  - runs ``warmup_iters`` untimed steps, then times ``timing_iters`` steps.

  Args:
    timing_iters: number of steps averaged into the reported time (>= 1).
    warmup_iters: number of untimed steps run before timing starts.

  Returns:
    Mean wall-clock time per step in milliseconds (also printed).

  Raises:
    ValueError: if ``timing_iters`` is less than 1.
  """
  if timing_iters < 1:
    raise ValueError(f"timing_iters must be >= 1, got {timing_iters}")

  key = jax.random.PRNGKey(12)
  # Separate keys so the labels are not an exact copy of the inputs (a shared
  # key would generate identical x_data and y_data).
  init_key, x_key, y_key = jax.random.split(key, 3)
  x_shape = (sequence_length, batch_size, hidden_size)
  x_data = jax.random.uniform(x_key, shape=x_shape, dtype=dtype)
  y_data = jax.random.uniform(y_key, shape=x_shape, dtype=dtype)

  basic_transformer = BasicTransformer(
      hidden_size,
      ffn_hidden_size,
  )

  # The whole variable dict returned by init (not only 'params') is stored
  # as the TrainState params; presumably so the FP8 scaling collections are
  # carried through apply_gradients too -- TODO confirm against Flax FP8 docs.
  init_var = basic_transformer.init(init_key, x_data)
  opt = optax.adam(learning_rate=0.1)
  state = TrainState.create(
      apply_fn=basic_transformer.apply, params=init_var, tx=opt)

  def step_fn(state, x, labels):
    """One forward/backward/update step; returns ``(new_state, loss)``."""

    def loss_fn(variables, x, labels):
      y = state.apply_fn(variables, x)
      return jnp.mean(jnp.square(y - labels))

    # The input gradient (argnums index 1) is computed but not applied;
    # presumably kept so the backward pass contains the activation-gradient
    # GEMM of a layer inside a deeper stack -- confirm intent.
    grad_fn = jax.value_and_grad(loss_fn, argnums=[0, 1])
    # Differentiate on the traced arguments. Closing over x_data / y_data
    # would embed both arrays in the compiled program as constants.
    loss, grads = grad_fn(state.params, x, labels)
    state = state.apply_gradients(grads=grads[0])
    return state, loss

  pjit_train_step_fn = pjit(step_fn)

  # Warmup runs. Block afterwards so no warmup work is still in flight when
  # the timer starts.
  for _ in range(warmup_iters):
    state, loss = pjit_train_step_fn(state, x_data, y_data)
  jax.block_until_ready(state)

  st = time.perf_counter()
  for _ in range(timing_iters):
    state, loss = pjit_train_step_fn(state, x_data, y_data)
  # JAX dispatches asynchronously. Without blocking, the timer would stop
  # when the work is enqueued rather than when it has finished.
  jax.block_until_ready((state, loss))
  elapsed_time = (time.perf_counter() - st) / timing_iters * 1000
  print(f"Mean time: {elapsed_time} ms")
  return elapsed_time

# Run the benchmark only when executed as a script (e.g. `python test.py`),
# not when this module is imported.
if __name__ == "__main__":
  run_benchmark()
philipphack commented 10 months ago

It looks like enabling the GELU epilogue for FP8 GEMMs would allow both the activation and the amax calculation to be fused into the library call.