ROCm / rocWMMA

https://rocm.docs.amd.com/projects/rocWMMA/
MIT License

Using rocwmma with pytorch #239

Closed fileaccent closed 1 year ago

fileaccent commented 1 year ago

I am converting CUDA code that uses WMMA to HIP with rocWMMA. My unit tests pass, and I now want to integrate the code into PyTorch as an extension. When I run "python setup.py install", the build compiles for every architecture in PyTorch's list and fails, because rocWMMA does not support gfx1030. What should I do to avoid this error? Can I compile for a single architecture only?

This is the content of the setup.py file:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
from pathlib import Path
import os

workspace_dir = Path(os.path.dirname(os.path.abspath(__file__)))

setup(
    name="fused_attn",
    ext_modules=[
        CUDAExtension(
            name="fused_attn",
            sources=[str(workspace_dir / "src" / "fused_attn_extention.cu")],
            include_dirs=[str(workspace_dir / "include")],
            extra_compile_args=[
                "-O3", 
                "-std=c++20", 
                "--offload-arch=gfx90a",
                "-I/opt/rocm/include",
                "-I/opt/rocm/hip/include"
                ],
        )
    ],
    cmdclass={
        "build_ext": BuildExtension
    }
)

The following is part of the error output. (I specified the architecture, but PyTorch still adds all architectures.)

[1/1] /opt/rocm/bin/hipcc  -I/data/zhaorong/code/fused-attention/include -I/opt/conda/lib/python3.8/site-packages/torch/include -I/opt/conda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/opt/conda/lib/python3.8/site-packages/torch/include/TH -I/opt/conda/lib/python3.8/site-packages/torch/include/THC -I/opt/conda/lib/python3.8/site-packages/torch/include/THH -I/opt/rocm/include -I/opt/rocm/miopen/include -I/opt/rocm/hip/include -I/opt/conda/include/python3.8 -c -c /data/zhaorong/code/fused-attention/src/fused_attn_extention.hip -o /data/zhaorong/code/fused-attention/build/temp.linux-x86_64-cpython-38/data/zhaorong/code/fused-attention/src/fused_attn_extention.o -fPIC -D__HIP_PLATFORM_HCC__=1 -DUSE_ROCM=1 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -std=c++20 

--offload-arch=gfx90a 

-I/opt/rocm/include -I/opt/rocm/hip/include -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=fused_attn -D_GLIBCXX_USE_CXX11_ABI=1 

--amdgpu-target=gfx900 --amdgpu-target=gfx906 --amdgpu-target=gfx908 --amdgpu-target=gfx90a --amdgpu-target=gfx1030 

-fno-gpu-rdc
FAILED: /data/zhaorong/code/fused-attention/build/temp.linux-x86_64-cpython-38/data/zhaorong/code/fused-attention/src/fused_attn_extention.o 
/opt/rocm/bin/hipcc  -I/data/zhaorong/code/fused-attention/include -I/opt/conda/lib/python3.8/site-packages/torch/include -I/opt/conda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/opt/conda/lib/python3.8/site-packages/torch/include/TH -I/opt/conda/lib/python3.8/site-packages/torch/include/THC -I/opt/conda/lib/python3.8/site-packages/torch/include/THH -I/opt/rocm/include -I/opt/rocm/miopen/include -I/opt/rocm/hip/include -I/opt/conda/include/python3.8 -c -c /data/zhaorong/code/fused-attention/src/fused_attn_extention.hip -o /data/zhaorong/code/fused-attention/build/temp.linux-x86_64-cpython-38/data/zhaorong/code/fused-attention/src/fused_attn_extention.o -fPIC -D__HIP_PLATFORM_HCC__=1 -DUSE_ROCM=1 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -std=c++20 --offload-arch=gfx90a -I/opt/rocm/include -I/opt/rocm/hip/include -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1013"' -DTORCH_EXTENSION_NAME=fused_attn -D_GLIBCXX_USE_CXX11_ABI=1 --amdgpu-target=gfx900 --amdgpu-target=gfx906 --amdgpu-target=gfx908 --amdgpu-target=gfx90a --amdgpu-target=gfx1030 -fno-gpu-rdc
Warning: The --amdgpu-target option has been deprecated and will be removed in the future.  Use --offload-arch instead.
Warning: The --amdgpu-target option has been deprecated and will be removed in the future.  Use --offload-arch instead.
Warning: The --amdgpu-target option has been deprecated and will be removed in the future.  Use --offload-arch instead.
Warning: The --amdgpu-target option has been deprecated and will be removed in the future.  Use --offload-arch instead.
Warning: The --amdgpu-target option has been deprecated and will be removed in the future.  Use --offload-arch instead.

In file included from /data/zhaorong/code/fused-attention/src/fused_attn_extention.hip:4:
In file included from /data/zhaorong/code/fused-attention/include/fused_attn_hip.cuh:6:
In file included from /opt/rocm-5.4.0/include/rocwmma/rocwmma.hpp:31:
In file included from /opt/rocm-5.4.0/include/rocwmma/internal/io_config.hpp:29:
In file included from /opt/rocm-5.4.0/include/rocwmma/internal/broadcast.hpp:29:
In file included from /opt/rocm-5.4.0/include/rocwmma/internal/types.hpp:339:
/opt/rocm-5.4.0/include/rocwmma/internal/types_ext.hpp:328:40: error: no matching conversion for static_cast from 'const rocwmma::hfloat16_t' (aka 'const __half') to 'rocwmma::float16_t' (aka '_Float16')
        return static_cast<hfloat16_t>(static_cast<float16_t>(x) * static_cast<float16_t>(y));
                                       ^~~~~~~~~~~~~~~~~~~~~~~~~
/opt/rocm/hip/include/hip/amd_detail/../../../../include/hip/amd_detail/amd_hip_fp16.h:233:13: note: candidate function
            operator __half_raw() const { return __half_raw{data}; }
            ^
/opt/rocm/hip/include/hip/amd_detail/../../../../include/hip/amd_detail/amd_hip_fp16.h:235:13: note: candidate function
            operator __half_raw() const volatile

Environment:
rocm: 5.4
ubuntu: 20.04
python: 3.8
pytorch: 1.12.1
GPU: MI210
rocwmma-dev: 0.7.0.50400-72~20.04

cgmillette commented 1 year ago

Hi @fileaccent thanks for reaching out!

Since you are attempting to integrate rocWMMA into another infrastructure, this would require some investigation into how PyTorch sets up the compiler and development environment. PyTorch appears to be ultimately responsible for setting the target flags and integrating other code, so these settings would be changed in the configuration of your PyTorch build.

That being said - the first line with the HIPCC call has the following:

-D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1

This is likely related to the compilation error you are seeing about the conversion between the __half and _Float16 types. Under the hood we support incoming __half data, but we require the ability to convert it to _Float16, which is native to AMD cards. If this conversion is done in another library (which is entirely possible), then you need to expose those conversions.
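To check which of these settings a given build actually received, the compiler command from the log can be inspected programmatically. The following is a minimal sketch and not part of rocWMMA or PyTorch: the helper name `summarize_hipcc_command` is hypothetical, and it only recognizes the flag spellings that appear in this thread.

```python
import shlex


def summarize_hipcc_command(command):
    """Collect GPU target flags and __HIP_NO_HALF_* defines from a hipcc command line.

    Hypothetical helper for illustration; it reports what was passed on the
    command line and makes no claim about how the compiler resolves the flags.
    """
    requested, excluded, half_defines = [], [], []
    for token in shlex.split(command):
        if token.startswith("--no-offload-arch="):
            excluded.append(token.split("=", 1)[1])
        elif token.startswith(("--offload-arch=", "--amdgpu-target=")):
            arch = token.split("=", 1)[1]
            if arch not in requested:  # the same arch may be passed twice
                requested.append(arch)
        elif token.startswith("-D__HIP_NO_HALF_"):
            half_defines.append(token[2:])  # drop the leading "-D"
    return {
        "requested": requested,
        "excluded": excluded,
        "half_defines": half_defines,
    }
```

Passing the full hipcc line from the error report to this function should list gfx1030 among the requested targets and show both `__HIP_NO_HALF_*` macros defined to 1.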

What I would recommend is to check the PyTorch documentation to see whether you can adjust your PyTorch configuration to build only for your intended targets, and to set the above flags to 0. I'm sure the PyTorch devs would also be accommodating of any questions you may have.
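One possible way to act on this advice from setup.py is sketched below. It rests on assumptions that should be verified against the installed PyTorch version: that `torch.utils.cpp_extension` reads the `PYTORCH_ROCM_ARCH` environment variable when choosing GPU targets (newer releases do, 1.12.1 may not), and that flags in `extra_compile_args` land after PyTorch's own `-D` flags, as the ordering in the log suggests. The macros are undefined with `-U` rather than set to 0, on the assumption that the HIP headers only test whether they are defined. The helper name `rocm_build_overrides` is hypothetical.

```python
import os


def rocm_build_overrides(target_arch):
    """Return (env, flags) intended to limit targets and re-enable __half conversions.

    Hypothetical helper; whether PyTorch honors these settings depends on its version.
    """
    # Assumption: the extension builder consults PYTORCH_ROCM_ARCH.
    env = {"PYTORCH_ROCM_ARCH": target_arch}
    # Assumption: these -U flags follow PyTorch's -D flags on the command line.
    flags = [
        "-U__HIP_NO_HALF_OPERATORS__",
        "-U__HIP_NO_HALF_CONVERSIONS__",
    ]
    return env, flags


env, extra_flags = rocm_build_overrides("gfx90a")
os.environ.update(env)  # must run before setup() invokes BuildExtension
# extra_flags would then be appended to extra_compile_args in CUDAExtension(...)
```

If the installed PyTorch ignores the environment variable, the full target list will still appear in the hipcc command line.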

Just FYI - looks like most of this configuration with the above is done in a 'gloo' cmake file within pytorch:
pytorch/tools/amd_build/build_amd.py
gloo/cmake/Dependencies.cmake
gloo/cmake/Hip.cmake

Cheers,

--Chris

fileaccent commented 1 year ago

I solved this problem later and am recording the solution here. (The method is not a particularly proper one. If there are other solutions, please share them.)

With hipcc, the "--no-offload-arch" option prevents compilation for a given architecture. Here I exclude every architecture except gfx90a. Below is a setup.py for reference.

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
from pathlib import Path
import os

workspace_dir = Path(os.path.dirname(os.path.abspath(__file__)))

setup(
    name="fused_attn",
    ext_modules=[
        CUDAExtension(
            name="fused_attn",
            sources=[str(workspace_dir / "src" / "fused_attn_extention.cu")],
            include_dirs=[str(workspace_dir / "include")],
            extra_compile_args=[
                "-O3", 
                "-std=c++20", 
                "-I/opt/rocm/include",
                "-I/opt/rocm/hip/include",
                "--no-offload-arch=gfx1030",
                "--no-offload-arch=gfx900",
                "--no-offload-arch=gfx906",
                "--no-offload-arch=gfx908"
                ],
        )
    ],
    cmdclass={
        "build_ext": BuildExtension
    }
)

This way, the build no longer fails from compiling for unsupported architectures.
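The exclusion flags can also be generated instead of typed by hand. A minimal sketch, assuming the default target list is the one shown in the error log (it may differ between PyTorch builds); the helper name `exclude_all_but` and the constant `DEFAULT_TARGETS` are hypothetical.

```python
# Targets PyTorch added in the log from this thread; an assumption for other builds.
DEFAULT_TARGETS = ["gfx900", "gfx906", "gfx908", "gfx90a", "gfx1030"]


def exclude_all_but(keep, targets=DEFAULT_TARGETS):
    """Build one --no-offload-arch flag for every target except `keep`."""
    if keep not in targets:
        raise ValueError(f"{keep!r} is not in the target list {targets}")
    return [f"--no-offload-arch={arch}" for arch in targets if arch != keep]


# The result can be concatenated onto extra_compile_args in setup.py.
exclusion_flags = exclude_all_but("gfx90a")
```

Raising on an unknown target guards against a typo silently excluding every architecture.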

You may also encounter the following error:

In file included from /data/zhaorong/code/fused-attention/src/fused_attn_extention.hip:4:
In file included from /data/zhaorong/code/fused-attention/include/fused_attn_hip.cuh:6:
In file included from /opt/rocm-5.4.0/include/rocwmma/rocwmma.hpp:31:
In file included from /opt/rocm-5.4.0/include/rocwmma/internal/io_config.hpp:29:
In file included from /opt/rocm-5.4.0/include/rocwmma/internal/broadcast.hpp:29:
In file included from /opt/rocm-5.4.0/include/rocwmma/internal/types.hpp:339:
/opt/rocm-5.4.0/include/rocwmma/internal/types_ext.hpp:328:40: error: no matching conversion for static_cast from 'const rocwmma::hfloat16_t' (aka 'const __half') to 'rocwmma::float16_t' (aka '_Float16')
        return static_cast<hfloat16_t>(static_cast<float16_t>(x) * static_cast<float16_t>(y));

To solve this problem, you need to modify the rocWMMA source file (types_ext.hpp). Change the three functions that report errors to the following form.

    // Workaround: reinterpret the __half operands as _Float16 through pointer casts,
    // which avoids the static_cast that fails to compile.
    __host__ inline hfloat16_t operator*(const hfloat16_t& x, const hfloat16_t& y)
    {
        float16_t mid1 = *(float16_t *)(void *)(&x);
        float16_t mid2 = *(float16_t *)(void *)(&y);
        mid1 = mid1 * mid2;
        return *(hfloat16_t *)(void *)&mid1;
    }

    __host__ inline hfloat16_t operator+(const hfloat16_t& x, const hfloat16_t& y)
    {
        float16_t mid1 = *(float16_t *)(void *)(&x);
        float16_t mid2 = *(float16_t *)(void *)(&y);
        mid1 = mid1 + mid2;
        return *(hfloat16_t *)(void *)&mid1;
    }

    __host__ inline hfloat16_t& operator+=(hfloat16_t& x, const hfloat16_t& y)
    {
        float16_t mid1 = *(float16_t *)(void *)(&x);
        float16_t mid2 = *(float16_t *)(void *)(&y);
        mid1 = mid1 + mid2;
        // Store the sum back into x and return the updated reference.
        return x = *(hfloat16_t *)(void *)&mid1;
    }

You should now be able to install this PyTorch extension. The method above is not a proper fix. If there is a better method, please tell me. Thank you so much.

cgmillette commented 1 year ago

Hi @fileaccent, Happy to see that you've achieved a solution.

I will take this experience and see what we can do to make rocWMMA aware of the flags -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1

This way in the future the source modification shouldn't be necessary.

Thank you very much for your feedback