intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License

[GEMM-perf] matmul is slower when one input needs to be transposed #1795

Closed: mgrabban closed this issue 3 days ago.

mgrabban commented 1 month ago

I find that matmul(X, Y) is ~4X slower when either X or Y needs to be transposed.

I have a matmul kernel that is similar to the one in the Triton matrix-multiplication tutorial.

The kernel is launched from this code:

import torch
import triton

# matmul_kernel_with_block_pointers is a Triton kernel defined elsewhere (similar to
# the tutorial kernel); BIAS_REQD tells it whether a bias b was passed.
def fused_mul_add(X, Y, b, transpose_x, transpose_y):
    # For a transposed input, swap its strides so the kernel reads it as the
    # transpose without materializing a copy.
    if transpose_x:
        K, M = X.shape
        Xstride0, Xstride1 = X.stride(1), X.stride(0)
    else:
        M, K = X.shape
        Xstride0, Xstride1 = X.stride(0), X.stride(1)
    if transpose_y:
        N, _ = Y.shape
        Wstride0, Wstride1 = Y.stride(1), Y.stride(0)
    else:
        _, N = Y.shape
        Wstride0, Wstride1 = Y.stride(0), Y.stride(1)
    # Allocates output.
    Z = torch.empty((M, N), device=X.device, dtype=X.dtype)
    # 1D launch grid: one program per BLOCK_SIZE_M x BLOCK_SIZE_N output tile.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel_with_block_pointers[grid](
        X, Y, b, Z,
        M, N, K,
        Xstride0, Xstride1,
        Wstride0, Wstride1,
        Z.stride(0), Z.stride(1),
        BIAS_REQD=b is not None,
    )

    return Z

Note that the strides of X or Y are swapped (e.g. Xstride0, Xstride1 = X.stride(1), X.stride(0)) when that input needs to be transposed.
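To make the stride swap concrete, here is a pure-Python sketch (no torch; helper names are hypothetical) of strided 2-D indexing over a flat buffer. It shows that swapping the two strides reinterprets the same storage as its transpose, and that walking along K then steps through memory M elements at a time instead of contiguously, which may change how efficiently the kernel's loads are served.

```python
# Pure-Python model of strided 2-D indexing over a flat buffer.
# A row-major rows x cols matrix has strides (cols, 1); swapping the strides
# reinterprets the same buffer as the transpose, as fused_mul_add does.

def element(buf, i, j, stride0, stride1):
    """Return logical element (i, j) of a strided view over buf."""
    return buf[i * stride0 + j * stride1]

def strided_view(buf, rows, cols, stride0, stride1):
    """Materialize a rows x cols strided view as a list of lists."""
    return [[element(buf, i, j, stride0, stride1) for j in range(cols)]
            for i in range(rows)]

def k_offsets(row, k, stride0, stride1):
    """Flat offsets touched when walking along K (dimension 1) for a fixed row."""
    return [row * stride0 + kk * stride1 for kk in range(k)]

# X is stored as a K x M = 2 x 3 row-major matrix, strides (3, 1).
K, M = 2, 3
buf = [1, 2, 3,
       4, 5, 6]
x = strided_view(buf, K, M, 3, 1)    # the stored K x M matrix
x_t = strided_view(buf, M, K, 1, 3)  # swapped strides: the M x K transpose
```

Here x_t is [[1, 4], [2, 5], [3, 6]]; k_offsets(0, 2, 1, 3) gives [0, 3] (stride M = 3 along K), whereas a non-transposed M x K input with strides (K, 1) gives consecutive offsets.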

I notice that if neither input needs to be transposed, performance is similar to PyTorch's matmul, but when either one needs to be transposed (so that the strides are swapped for that input), performance is 4X slower.

This does not happen on CUDA devices, so could you please look into making it efficient on XPU devices as well?

vlad-penkin commented 1 month ago

@mgrabban thanks for the feedback. Could you please provide information on your runtime environment:

mgrabban commented 1 month ago


  • GPU HW model. Please note that all matmul performance optimizations are currently available only for PVC.

I am doing this on PVC (Intel GPU Max 1550).

  • Agama Driver version. Please note that all matmul performance optimizations are only available with the latest Rolling Driver.

My Agama version is 950.4

  • PyTorch or IPEX version or commit ID. Please note that regular IPEX is not supported; we are in the final stages of removing the dependency on the "special IPEX test proxy" and switching fully to upstream PyTorch.

I am using PyTorch/IPEX installed with the script in the scripts folder.

I am using oneAPI/2024.2.0

vlad-penkin commented 1 month ago

Could you please retest with upstream PyTorch?

To build upstream PyTorch from source, run the following script:

./scripts/compile-pytorch-ipex.sh --pytorch --upstream-pytorch --source

Our tutorial code still has an import intel_extension_for_pytorch line. You can either comment it out or install a dummy no-op IPEX package with this script:

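import sys
from os import chdir, getcwd, makedirs
from subprocess import run
from tempfile import TemporaryDirectory

# Builds and installs an empty placeholder package so that
# "import intel_extension_for_pytorch" in the tutorials becomes a no-op.
pkg = "intel_extension_for_pytorch"
orig_dir = getcwd()
with TemporaryDirectory() as tmpdir:
    chdir(tmpdir)
    makedirs(pkg, exist_ok=True)
    files = {
        f"{pkg}/__init__.py": "",
        "setup.py": (
            "from setuptools import setup, find_packages\n"
            f"setup(name='{pkg}', version='2', packages=find_packages())"
        ),
        # The build frontend reads its configuration from pyproject.toml.
        "pyproject.toml": (
            "[build-system]\n"
            "requires = [\"setuptools\", \"wheel\"]\n"
            "build-backend = \"setuptools.build_meta\""
        )
    }
    for file, content in files.items():
        with open(file, "w") as f:
            f.write(content)
    # Use the current interpreter so the package lands in the active environment.
    cmds = [
        [sys.executable, "-m", "pip", "uninstall", "-y", pkg],
        [sys.executable, "-m", "pip", "install", "build"],
        [sys.executable, "-m", "build", "."],
        [sys.executable, "-m", "pip", "install", f"dist/{pkg}-2-py3-none-any.whl"],
    ]
    for cmd in cmds:
        run(cmd, check=True)
    # Leave the temporary directory before it is deleted.
    chdir(orig_dir)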
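To confirm that the placeholder package is visible to the interpreter that runs the tutorials, a minimal sketch using the standard library's importlib.util.find_spec (the helper name has_module is hypothetical; it checks importability only, not that the installed package is the no-op one):

```python
import importlib.util

def has_module(name: str) -> bool:
    """Return True if the import system can locate `name`, without importing it."""
    return importlib.util.find_spec(name) is not None

if __name__ == "__main__":
    pkg = "intel_extension_for_pytorch"
    print(f"{pkg} importable: {has_module(pkg)}")
```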

mgrabban commented 1 month ago

@vlad-penkin the PyTorch/IPEX installation script keeps changing. Yesterday I tried your command; it installed, but the matmul run failed on the IPEX import, so I commented that import out.

Today the install itself fails. I tried ./scripts/compile-pytorch-ipex.sh --upstream-pytorch --source --venv and it gave this error:

CMake Error at third_party/kineto/libkineto/src/plugin/xpupti/CMakeLists.txt:23 (find_package):
  By not providing "FindPti.cmake" in CMAKE_MODULE_PATH this project has
  asked CMake to find a package configuration file provided by "Pti", but
  CMake did not find one.

  Could not find a package configuration file provided by "Pti" with any of
  the following names:

    PtiConfig.cmake
    pti-config.cmake

  Add the installation prefix of "Pti" to CMAKE_PREFIX_PATH or set "Pti_DIR"
  to a directory containing one of the above files.  If "Pti" provides a
  separate development package or SDK, be sure it has been installed.

Are you able to run matmul/triton benchmarck.py from your end?

mgrabban commented 1 month ago

The installation issue is now fixed, but timing is now broken: the Triton perf time shows as 0.0. I think this warning is the reason: WARNING:root:Wall time is used instead of elapsed_time (not supported). The timing measurements could be innacurate.

vlad-penkin commented 1 month ago

@mgrabban thanks for the update!

See my notes below:

  1. You are seeing the warning because the PyTorch build you are using does not support the XPUEvent elapsed_time feature. To enable it, build PyTorch with the additional PRs we recommend: ./scripts/compile-pytorch-ipex.sh --upstream-pytorch --venv
  2. To build upstream PyTorch, you need to install and activate a matching PTI; it is no longer optional for the upstream PyTorch build.
  3. For more details, see the discussion of a similar topic in:
    • #1974
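One plausible way a wall-clock fallback can under-report is that GPU launches are asynchronous: unless the host waits for queued device work, the timer measures little more than launch overhead. A minimal sketch of wall-clock timing with explicit synchronization (wall_time_ms is a hypothetical helper; on XPU the sync callback would be something like torch.xpu.synchronize, assumed to be available in your PyTorch build):

```python
import statistics
import time
from typing import Callable

def wall_time_ms(fn: Callable[[], object], sync: Callable[[], None],
                 warmup: int = 3, reps: int = 10) -> float:
    """Median wall-clock time of fn() in milliseconds.

    `sync` must block until all queued device work has finished; otherwise
    only the (asynchronous) launch is timed.
    """
    for _ in range(warmup):
        fn()
    sync()
    samples = []
    for _ in range(reps):
        sync()                       # start from an idle device
        start = time.perf_counter()
        fn()
        sync()                       # wait for the work fn() enqueued
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)
```

Device events with elapsed_time avoid host-side noise, so they remain preferable when the PyTorch build supports them.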

mgrabban commented 1 month ago

@vlad-penkin I'm now able to run and get perf data, shown below:

{'torch_inf': 0.15876160562038422,
 'torch_train': 0.42427361011505127,
 'triton_inf': 0.1633344143629074,
 'triton_train': 1.8272528648376465}

As you can see, the issue is not resolved: inference involving matmul(A, B) is performant while training that additionally involves matmul(A, B^T) is not.
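To put a number on that gap, a small sketch computing the Triton-to-PyTorch ratio from the dictionary above, assuming the values are run times (lower is better; the units are not stated):

```python
perf = {'torch_inf': 0.15876160562038422,
        'torch_train': 0.42427361011505127,
        'triton_inf': 0.1633344143629074,
        'triton_train': 1.8272528648376465}

def slowdown(perf, mode):
    """Ratio of Triton time to PyTorch time for mode 'inf' or 'train'."""
    return perf[f"triton_{mode}"] / perf[f"torch_{mode}"]

for mode in ("inf", "train"):
    print(f"{mode}: Triton / PyTorch = {slowdown(perf, mode):.2f}x")
```

With these numbers it prints 1.03x for inference and 4.31x for training, consistent with the ~4X slowdown in the original report.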

arunjose696 commented 3 weeks ago

@mgrabban, what matrix sizes are you using? I could not run triton_inf or triton_train because they were not shared. However, I ran the matmul kernel from the Triton tutorials with and without transposing each of the inputs a and b, for various matrix sizes.

I used this code to launch the kernel. It is a slightly modified version of your code, except that it does only a multiply instead of fused_mul_add:

import torch
import triton

# matmul_kernel is the kernel from the Triton matrix-multiplication tutorial.
def matmul(X, Y, transpose_x, transpose_y, activation=""):
    # For a transposed input, swap its strides instead of copying it.
    if transpose_x:
        K, M = X.shape
        Xstride0, Xstride1 = X.stride(1), X.stride(0)
    else:
        M, K = X.shape
        Xstride0, Xstride1 = X.stride(0), X.stride(1)
    if transpose_y:
        N, _ = Y.shape
        Ystride0, Ystride1 = Y.stride(1), Y.stride(0)
    else:
        _, N = Y.shape
        Ystride0, Ystride1 = Y.stride(0), Y.stride(1)

    # Allocates output.
    Z = torch.empty((M, N), device=X.device, dtype=torch.float16)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        X, Y, Z,  #
        M, N, K,  #
        Xstride0, Xstride1,  #
        Ystride0, Ystride1,  #
        Z.stride(0), Z.stride(1),  #
        ACTIVATION=activation  #
    )
    return Z
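The grid lambda launches cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N) programs, and the tutorial kernel maps each 1-D program id to an output tile using "grouped ordering" controlled by GROUP_SIZE_M. A pure-Python sketch of that mapping, assuming matmul_kernel follows the tutorial's scheme (the block and group sizes below are illustrative, not the autotuned values):

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, same result as triton.cdiv for positive ints."""
    return (a + b - 1) // b

def tile_of(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
    """Map a 1-D program id to (pid_m, pid_n) with grouped ordering."""
    num_pid_m = cdiv(M, BLOCK_SIZE_M)
    num_pid_n = cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may have fewer than GROUP_SIZE_M rows of tiles.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

M, N, BM, BN, G = 256, 192, 64, 64, 2
grid_size = cdiv(M, BM) * cdiv(N, BN)  # what the grid lambda computes: 12
tiles = [tile_of(pid, M, N, BM, BN, G) for pid in range(grid_size)]
```

Consecutive program ids walk down GROUP_SIZE_M rows of tiles before moving to the next column, so programs running at the same time share input blocks; every tile is still visited exactly once.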

Below are my results for different matrix sizes:


M | N | K | A*B (timings) | A_transposed*B (timings) | A*B_transposed (timings)
-- | -- | -- | -- | -- | --
256 | 256 | 256 | 1.318964 | 0.907858 | 1.226405
384 | 384 | 384 | 3.118012 | 2.131894 | 2.900774
512 | 512 | 512 | 5.785247 | 3.892625 | 5.412005
640 | 640 | 640 | 9.077008 | 5.217835 | 7.710117
768 | 768 | 768 | 13.291809 | 7.600417 | 11.168265
896 | 896 | 896 | 18.055299 | 10.394843 | 15.239896
1024 | 1024 | 1024 | 15.391941 | 8.935934 | 14.143069
1152 | 1152 | 1152 | 20.20116 | 11.388735 | 18.375286
1280 | 1280 | 1280 | 23.831273 | 13.951251 | 21.836236
1408 | 1408 | 1408 | 17.213304 | 8.707603 | 14.907655
1536 | 1536 | 1536 | 20.06133 | 10.370532 | 17.148773
1664 | 1664 | 1664 | 24.393493 | 12.220038 | 20.981069
1792 | 1792 | 1792 | 27.560273 | 14.140419 | 23.638617
1920 | 1920 | 1920 | 22.512367 | 11.050912 | 18.981677
2048 | 2048 | 2048 | 24.232494 | 12.344698 | 20.752645
2176 | 2176 | 2176 | 26.653839 | 13.926401 | 23.222386
2304 | 2304 | 2304 | 24.732246 | 12.112852 | 20.724194
2432 | 2432 | 2432 | 25.813591 | 13.186987 | 22.266819
2560 | 2560 | 2560 | 28.185633 | 14.555469 | 24.545318
2688 | 2688 | 2688 | 27.394669 | 13.298907 | 23.200646
2816 | 2816 | 2816 | 29.298933 | 14.379667 | 24.988221
2944 | 2944 | 2944 | 28.007605 | 13.306519 | 23.660148
3072 | 3072 | 3072 | 30.752535 | 14.676327 | 25.870064
3200 | 3200 | 3200 | 29.507959 | 13.703122 | 24.627964
3328 | 3328 | 3328 | 29.311299 | 14.481528 | 24.880887
3456 | 3456 | 3456 | 29.818425 | 13.93255 | 24.705676
3584 | 3584 | 3584 | 31.112594 | 14.856676 | 26.305472
3712 | 3712 | 3712 | 31.907325 | 15.681895 | 27.125861
3840 | 3840 | 3840 | 33.08352 | 15.447832 | 27.721634
3968 | 3968 | 3968 | 30.94293 | 14.636378 | 25.89539
4096 | 4096 | 4096 | 32.431981 | 15.491316 | 27.195817
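How to read a table like this depends on whether its numbers are run times (lower is better) or throughput (higher is better). The Triton matmul tutorial's benchmark, for example, reports TFLOPS computed from the measured milliseconds as 2 * M * N * K / time. A small conversion sketch (the helper names are hypothetical) for putting both kinds of numbers on the same footing:

```python
def tflops(M: int, N: int, K: int, ms: float) -> float:
    """Throughput in TFLOPS for an (M x K) @ (K x N) matmul that took `ms` ms.

    Counts one multiply and one add per inner-product term: 2 * M * N * K FLOPs.
    """
    return 2 * M * N * K * 1e-12 / (ms * 1e-3)

def ms_from_tflops(M: int, N: int, K: int, tf: float) -> float:
    """Inverse of tflops(): run time in milliseconds for a given throughput."""
    return 2 * M * N * K * 1e-12 / (tf * 1e-3)
```

For example, at M = N = K = 4096 a 1 ms run corresponds to about 137.4 TFLOPS.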

I also tried modifying the tutorial kernel into a fused multiply-add and still get similar numbers. I don't see a performance degradation when one of the inputs is transposed; instead, I see a slight performance increase. Could you recheck that you are using the latest Agama drivers and upstream PyTorch, and run the kernel from this tutorial with the launch script above? Please let me know whether the performance degradation still exists when running the matrix multiplication alone, since other functionality in triton_inf or triton_train might have an unexpected effect.

These are my hardware details:

LIBIGC1_VERSION=1.0.17193.16-950
LEVEL_ZERO_VERSION=1.3.30049.10-950
AGAMA_VERSION=950
GPU_DEVICE=Intel(R) Data Center GPU Max 1100

alexbaden commented 3 weeks ago

@mgrabban could you provide us with the cached Triton-generated code for both runs (with and without transpose)? The easiest way to do this is to delete your Triton cache (rm -rf ~/.triton/cache) and then run both kernels. You should see ~5 folders in the cache directory. Two will contain several files ending in .ttir, .llir, .spv, etc.: one for the transposed run and one for the non-transposed run. Can you copy both folders here so we can examine the IR? That will also let us run your generated code verbatim on our systems.
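To pick out the two folders that hold IR, a minimal sketch that scans a Triton cache directory and lists the subfolders containing IR files (ir_dirs is a hypothetical helper; the suffix set is taken from the comment above, plus .ttgir, and may need adjusting for your Triton version):

```python
from pathlib import Path

IR_SUFFIXES = {".ttir", ".ttgir", ".llir", ".spv"}

def ir_dirs(cache_dir):
    """Return {subfolder name: sorted IR file names} for cache subfolders with IR."""
    root = Path(cache_dir).expanduser()
    if not root.is_dir():
        return {}
    found = {}
    for sub in sorted(p for p in root.iterdir() if p.is_dir()):
        names = sorted(f.name for f in sub.iterdir() if f.suffix in IR_SUFFIXES)
        if names:
            found[sub.name] = names
    return found

if __name__ == "__main__":
    for name, files in ir_dirs("~/.triton/cache").items():
        print(name, files)
```

Run it after clearing the cache and executing both kernels; the folders it prints are the ones to attach.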

mgrabban commented 3 weeks ago

Just for reference: a single-file reproducer was provided to @alexbaden.