NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

performance optimization for moving distributive reduction operation across broadcast dimensions #1660

Open jjsjann123 opened 7 months ago

jjsjann123 commented 7 months ago

I'm putting this here for the backlog.

In the example below, the two models produce the same result. Pulling the reduction op ahead of the broadcast + mul seems to give much better codegen performance on this toy example.
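The equivalence comes from multiplication distributing over a sum. Write `input` as $x_{ij}$ (shape `(B, M, 1)` with the trailing broadcast dimension dropped) and `b` as $b_{ik}$ (shape `(B, 1, N)` with the middle broadcast dimension dropped). `model_ref` computes the left-hand side and `model` computes the right-hand side:

$$
o_{ik} \;=\; \sum_{j=1}^{M} x_{ij}\, b_{ik} \;=\; b_{ik} \sum_{j=1}^{M} x_{ij}
$$

The factor $b_{ik}$ does not depend on the reduced index $j$, so it can be pulled out of the sum. The left side forms one product per $(i, j, k)$ before reducing. The right side first reduces $x$ to one value per $i$ and then needs only one multiplication per output element.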

However, I'm currently hitting bug #1659 when trying to scale it up. I'll update this issue once I patch the transpose scheduler.

```python
import torch
import nvfuser
from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype
from torch import Tensor
from typing import List


def model(input: Tensor, b: Tensor) -> List[Tensor]:
    # Reduction pulled ahead of the broadcast + mul
    tensors = [input, b]
    with nvfuser.FusionDefinition() as fd:
        # x: (B, M, 1) with a trailing broadcast; y: (B, 1, N) with a broadcast in the middle
        x = fd.define_tensor((-1, -1, 1), contiguity=(True, True, None), dtype=torch_dtype_to_nvfuser_dtype(input.dtype), stride_order=(2, 1, 0))
        y = fd.define_tensor((-1, 1, -1), contiguity=(True, None, True), dtype=torch_dtype_to_nvfuser_dtype(b.dtype), stride_order=(2, 1, 0))
        T1 = fd.ops.sum(x, 1)  # the real reduction, over the M axis
        T2 = fd.ops.sum(y, 1)  # dim 1 of y has size 1, so this sum is effectively a squeeze
        T3 = fd.ops.mul(T1, T2)
        fd.add_output(T3)
    return fd.execute(tensors)


def model_ref(input: Tensor, b: Tensor) -> List[Tensor]:
    # Reference: broadcast + mul first, then reduce
    tensors = [input, b]
    with nvfuser.FusionDefinition() as fd:
        x = fd.define_tensor((-1, -1, 1), contiguity=(True, True, None), dtype=torch_dtype_to_nvfuser_dtype(input.dtype), stride_order=(2, 1, 0))
        y = fd.define_tensor((-1, 1, -1), contiguity=(True, None, True), dtype=torch_dtype_to_nvfuser_dtype(b.dtype), stride_order=(2, 1, 0))
        T1 = fd.ops.mul(x, y)  # one product per (i, j, k)
        T2 = fd.ops.sum(T1, 1)  # reduce over the M axis
        fd.add_output(T2)
    return fd.execute(tensors)


# Inputs
inputs = [
    torch.randn(8, 4, 1, device="cuda"),
    torch.randn(8, 1, 4, device="cuda"),
]

# Repro: a few warm-up iterations, then check that both fusions agree
for i in range(3):
    o = model(*inputs)
    o_ref = model_ref(*inputs)

print(o[0] - o_ref[0])
assert o[0].allclose(o_ref[0])

# Profiled region (e.g. for nsys with --capture-range=cudaProfilerApi)
torch.cuda.profiler.start()
for i in range(3):
    o = model(*inputs)
    o_ref = model_ref(*inputs)
torch.cuda.profiler.stop()
```
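The same algebra can be checked on the CPU with plain Python lists. This sketch covers only the math and makes no claim about what nvFuser generates. The helper names `reduce_after_mul` and `reduce_before_mul` are hypothetical. The sketch also counts multiplications to show why reducing first does less arithmetic.

```python
import random


def reduce_after_mul(x, b):
    """Reference ordering (as in model_ref): broadcast-multiply, then sum over j.

    x[i][j] stands in for the (B, M, 1) input, b[i][k] for the (B, 1, N) input.
    Returns the (B, N) result and the number of multiplications performed.
    """
    muls, out = 0, []
    for xi, bi in zip(x, b):
        row = []
        for bik in bi:
            acc = 0.0
            for xij in xi:
                acc += xij * bik  # one product per (i, j, k)
                muls += 1
            row.append(acc)
        out.append(row)
    return out, muls


def reduce_before_mul(x, b):
    """Rewritten ordering (as in model): sum over j first, then broadcast-multiply."""
    muls, out = 0, []
    for xi, bi in zip(x, b):
        s = sum(xi)  # the reduction no longer runs over the broadcast dimension k
        row = []
        for bik in bi:
            row.append(s * bik)
            muls += 1
        out.append(row)
    return out, muls


random.seed(0)
B, M, N = 8, 4, 4  # same sizes as the repro inputs
x = [[random.gauss(0, 1) for _ in range(M)] for _ in range(B)]
b = [[random.gauss(0, 1) for _ in range(N)] for _ in range(B)]

ref, ref_muls = reduce_after_mul(x, b)
opt, opt_muls = reduce_before_mul(x, b)
# Results match up to floating-point regrouping error
assert all(abs(r - o) < 1e-9 for rr, oo in zip(ref, opt) for r, o in zip(rr, oo))
print(ref_muls, opt_muls)
```

With these sizes, the reference ordering performs B·M·N = 128 multiplications and the rewritten one performs B·N = 32. The gap grows with the size of the reduced axis.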
jjsjann123 commented 5 months ago

A similar problem showed up with layernorm + amax. The segmenter breaks the amax out into a separate kernel. Instead, it should split the amax into two reductions, so that the first half can be fused into the layernorm kernel. That saves bandwidth in both kernels.

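```python
# Excerpt from the model: x, bias, residual, w_shape, ln_weight, ln_bias,
# eps, zero_centered_gamma and amax_tensor come from the surrounding code.
import torch

out1 = (x + bias) + residual
out1 = out1.view(-1, w_shape[-1])

# LayerNorm
w = ln_weight if not zero_centered_gamma else 1 + ln_weight
out2 = torch.nn.functional.layer_norm(out1, w_shape, w, ln_bias, eps)

# Obtain FP8 amax
amax = torch.amax(torch.abs(out2))
amax_tensor.fill_(amax)
```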
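Below is a minimal sketch of the proposed split. It assumes a 2-D `(rows, hidden)` input like `out1` after the `view`. The first pass normalizes each row and also records that row's partial `amax` while the row is being processed. A second, much smaller reduction then takes the max over the per-row partials, so it reads one value per row instead of the whole normalized tensor. This is plain Python meant to illustrate the data flow, not nvFuser's segmenter. The names `layer_norm_with_partial_amax` and `finish_amax` are hypothetical.

```python
import math


def layer_norm_with_partial_amax(rows, weight, bias, eps=1e-5):
    """Stage 1 (the part that could fuse with layernorm): normalize each row
    and record max(|y|) over that row."""
    out, partial_amax = [], []
    for row in rows:
        n = len(row)
        mean = sum(row) / n
        var = sum((v - mean) ** 2 for v in row) / n  # biased variance, as layer_norm uses
        inv_std = 1.0 / math.sqrt(var + eps)
        y = [(v - mean) * inv_std * w + bb for v, w, bb in zip(row, weight, bias)]
        out.append(y)
        partial_amax.append(max(abs(v) for v in y))  # per-row partial reduction
    return out, partial_amax


def finish_amax(partial_amax):
    """Stage 2: reduce the per-row partials to the global amax."""
    return max(partial_amax)


rows = [[1.0, -2.0, 3.0, 0.5], [4.0, 4.0, -1.0, 2.0], [0.0, 0.1, -0.3, 0.2]]
weight = [1.0, 0.5, 2.0, 1.0]
bias = [0.0, 0.1, -0.1, 0.0]

out2, partials = layer_norm_with_partial_amax(rows, weight, bias)
amax = finish_amax(partials)
# max is exact, so the two-stage result equals a single full reduction
assert amax == max(abs(v) for row in out2 for v in row)
```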
jjsjann123 commented 5 months ago

Had a quick chat with @naoyam.

I think fusing the kernels is safe as long as the reduction tensor matches the reference tensor of the normalization, which is not the case above. But we could still break the amax into two amax reductions and match the first one with the layernorm.

I need to think more about how this could be figured out automatically. I'm wondering whether it makes sense to always decompose a full reduction and try our luck with it. :confused:
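The "always decompose a full reduction" idea can be sketched generically. For an associative, commutative operator, a full reduction over a `(rows, cols)` tensor equals an inner reduction over `cols` followed by an outer reduction over the per-row partials. The inner stage yields one value per row, which is the same shape as the per-row statistics of a normalization over the last axis. This sketch assumes that shape match is what would let the first stage line up with the normalization's reference tensor; the issue does not confirm it. `full_reduction`, `decompose_full_reduction` and `abs_max` are hypothetical helpers.

```python
from functools import reduce
import operator


def full_reduction(matrix, op):
    """Single-stage reduction over every element."""
    return reduce(op, (v for row in matrix for v in row))


def decompose_full_reduction(matrix, op):
    """Two-stage version: reduce each row (inner axis), then the per-row partials."""
    partials = [reduce(op, row) for row in matrix]  # one value per row
    return reduce(op, partials), partials


def abs_max(a, b):
    """Binary amax: associative and commutative, like max."""
    return max(abs(a), abs(b))


m = [[3, -7, 2], [5, 1, -4], [0, 6, -1]]
for op in (abs_max, max, min, operator.add):
    total, partials = decompose_full_reduction(m, op)
    assert total == full_reduction(m, op)
    print(op.__name__, partials, total)
```

The example uses integers so that `operator.add` is exact. With floating-point `sum`, regrouping the reduction can change rounding, whereas `max`, `min` and `amax` give identical results regardless of grouping.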