ROCm / HIPIFY

HIPIFY: Convert CUDA to Portable C++ Code
https://rocm.docs.amd.com/projects/HIPIFY/en/latest/
MIT License

[HIPIFY] hipify-clang misbehaves in the presence of preprocessor directives #25

Closed ChrisKitching closed 4 years ago

ChrisKitching commented 6 years ago

Consider the following CUDA program:

__global__ void axpy_kernel(float a, float* x, float* y) {
    y[threadIdx.x] = a * x[threadIdx.x];
}

void axpy(float a, float* x, float* y) {
    axpy_kernel<<<1, 4>>> (a, x, y);
#ifdef SOME_MACRO
    axpy_kernel<<<1, 4>>> (a, x, y);
#endif
}

The result of hipifying this, if you don't pass -DSOME_MACRO, is:

#include <hip/hip_runtime.h>
__global__ void axpy_kernel(float a, float* x, float* y) {
    y[hipThreadIdx_x] = a * x[hipThreadIdx_x];
}

void axpy(float a, float* x, float* y) {
    hipLaunchKernelGGL(axpy_kernel, dim3(1), dim3(4), 0, 0, a, x, y);
#ifdef SOME_MACRO
    axpy_kernel<<<1, 4>>> (a, x, y);
#endif

}

Respecting conditional macros isn't the right thing to do with this sort of mechanised refactoring - what you really want to do is walk the entire tree applying your refactor, regardless of preprocessor conditionals.

This is going to present a relatively nasty obstacle to people with complicated CUDA programs they want to translate...

emankov commented 6 years ago

The preprocessor is needed to expand macros so that they can be hipified as well. The obstacle is known: we tried a few approaches, including finding all macro conditions and calling hipify-clang repeatedly on the same source with the appropriate define set each time, but that turned out to be very complicated and caused a real performance drop. In fact, hipification needs something slightly strange: we need the preprocessor, and at the same time we need the source code untouched by the preprocessor.

ChrisKitching commented 6 years ago

The preprocessor is needed to expand macros so that they can be hipified as well.

Can you give an example of a situation when you need to have a macro expanded?

Why isn't it enough to just look at the source code, fix all the macro definitions, and move on? When do you need the preprocessed source?

emankov commented 6 years ago

Why isn't it enough to just look at the source code, fix all the macro definitions, and move on?

Preprocessor directives are not included in the AST.

Can you give an example of a situation when you need to have a macro expanded? When do you need the preprocessed source?

CUDA_8.0/CUDASamples/common/inc/helper_cuda.h:

#define checkCudaErrors(val)           check ( (val), #val, __FILE__, __LINE__ )

CUDA_8.0\include\driver_types.h:

#define cudaEventDisableTiming              0x02  /**< Event will not record timing data */

CUDA_8.0/CUDASamples/6_Advanced/concurrentKernels/concurrentKernels.cu:

#include <helper_cuda.h>
...
cudaEvent_t *kernelEvent;
kernelEvent = (cudaEvent_t *)malloc(nkernels * sizeof(cudaEvent_t));
for (int i = 0; i < nkernels; i++)
{
  checkCudaErrors(cudaEventCreateWithFlags(&(kernelEvent[i]), cudaEventDisableTiming));
}

ChrisKitching commented 6 years ago

So I've noticed that hipify is inlining some macros in a quite destructive way. Consider:

// A handy macro for asserting that a CUDA API call succeeds.
#define CHECK(cmd) \
{\
    cudaError_t error = cmd;\
    if (error != cudaSuccess) { \
        fprintf(stderr, "error: '%s'(%d) at %s:%d\n", cudaGetErrorString(error), error,__FILE__, __LINE__); \
        exit(EXIT_FAILURE);\
    }\
}

void foo() {
    cudaDeviceProp props;

    // Get cuda device information, and crash if something fails.
    CHECK(cudaGetDeviceProperties(&props, 0));
}

hipify translates that into (comments mine, of course):

// A correctly hipified definition of the check macro. It now does the same error checking it
// did for CUDA, but for hip! :D
#define CHECK(cmd) \
{\
    hipError_t error = cmd;\
    if (error != hipSuccess) { \
        fprintf(stderr, "error: '%s'(%d) at %s:%d\n", hipGetErrorString(error), error,__FILE__, __LINE__); \
        exit(EXIT_FAILURE);\
    }\
}

void foo() {
    hipDeviceProp_t props;

    // What is this? This doesn't invoke the CHECK macro. This will silently discard any hip
    // error. `hipify` has changed the semantics, even though it correctly translated that macro.
    // `hipify` could've just left the CHECK() macro in place around the hipGetDeviceProperties
    // call and everything would've been perfect.
    hipGetErrorString(hipGetDeviceProperties(&props, 0));
}

ChrisKitching commented 6 years ago

Relatedly: it'd be nice if hipGetErrorString had the [[nodiscard]] attribute when compiling in C++17 mode - although I suspect you guys don't support C++17 yet.

emankov commented 6 years ago

Relatedly: it'd be nice if hipGetErrorString had the [[nodiscard]] attribute when compiling in C++17 mode - although I suspect you guys don't support C++17 yet.

Well, HIP is built by HCC, which is based on ToT clang (currently version 5.0); hipify-clang is also built by clang, whose version may be (with your changes) anywhere from 3.8 up to ToT 5.0. The [[nodiscard]] attribute has been supported by clang since 3.9. We can try it; I think it should work: 1. change the option -std=c++11 to -std=c++17 in hipify-clang's cmake; 2. add the [[nodiscard]] attribute to the hipGetErrorString function in hcc_detail/hip_runtime_api.h; and 3. build HIP with hipify-clang (by passing -DHIPIFY_CLANG_LLVM_DIR to cmake) from scratch.

bensander commented 6 years ago

Seems like a fine addition if we guard it appropriately (so it only gets used on C++17). hip_runtime_api.h provides a C API, so it needs to be compilable with vanilla C compilers such as gcc.

From: Evgeny Mankov [mailto:notifications@github.com] Sent: Wednesday, October 18, 2017 2:47 PM To: ROCm-Developer-Tools/HIP HIP@noreply.github.com Cc: Subscribed subscribed@noreply.github.com Subject: Re: [ROCm-Developer-Tools/HIP] [HIPIFY] hipify-clang misbehaves in the presence of preprocessor directives (#207)

ChrisKitching commented 6 years ago

Seems like a fine addition if we guard it appropriately (so it only gets used on C++17). hip_runtime_api.h provides a C API, so it needs to be compilable with vanilla C compilers such as gcc.

A preprocessor check for the value of __cplusplus being >= 201703L seems like it should work on absolutely every compiler ever.

ChrisKitching commented 6 years ago

I've been able to resolve virtually all preprocessor-related hipify problems that aren't to do with conditional macros. By being careful about how you pick your SourceLocations, it's possible to cope with all but the most contrived use of macros in kernel launches (see testcases).

The idea here is to choose between reading from the expansion location or the spelling location depending on whether a preprocessor macro expansion sits astride the expansion location of the thing you want. If the target range contains macros, you can just copy it verbatim from the expansion location. If it is contained within another macro, and the borders of that macro do not align with the borders of the AST element you're trying to read, you have to inline it. Otherwise you can update it in the macro definition itself.

As for handling conditional macros properly, the best I've come up with so far is to move the work that's currently done with preprocessor callbacks into a Lexer-time FrontendAction. This allows you to run a lexer in raw mode and translate identifiers as they are being lexed. The practical effect of this is that it will correctly apply identifier rename transformations inside preprocessor-deleted blocks, but (of course) won't do anything that is done at the AST-level, such as kernel launches.
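
To illustrate why a lexer-level pass is unaffected by conditionals, here is a toy sketch. It does not use clang's Lexer; rename_identifiers is a hypothetical helper that scans raw text for identifiers and renames the ones found in a table. Because it never evaluates #ifdef, it also rewrites identifiers inside blocks the preprocessor would delete. Unlike a real lexer, it does not skip string literals or comments.

```cpp
#include <cctype>
#include <cstddef>
#include <string>
#include <unordered_map>

// Toy stand-in for a raw-mode lexer pass (hypothetical helper, not clang API).
// Walks the raw source text, finds identifier tokens, and replaces those that
// appear in `renames`. Preprocessor conditionals are never evaluated, so code
// in inactive blocks is renamed too. String literals and comments are not
// skipped, which a real lexer would handle.
std::string rename_identifiers(
    const std::string& src,
    const std::unordered_map<std::string, std::string>& renames) {
    std::string out;
    out.reserve(src.size());
    std::size_t i = 0;
    while (i < src.size()) {
        const unsigned char c = static_cast<unsigned char>(src[i]);
        if (std::isalpha(c) || c == '_') {
            // Consume one whole identifier: [A-Za-z_][A-Za-z0-9_]*
            const std::size_t start = i;
            while (i < src.size() &&
                   (std::isalnum(static_cast<unsigned char>(src[i])) ||
                    src[i] == '_')) {
                ++i;
            }
            const std::string ident = src.substr(start, i - start);
            const auto it = renames.find(ident);
            out += (it != renames.end()) ? it->second : ident;
        } else {
            // Anything else (punctuation, whitespace, digits) is copied as-is.
            out += src[i++];
        }
    }
    return out;
}
```

Given a table mapping cudaError_t to hipError_t and cudaSuccess to hipSuccess, a line such as `cudaError_t e = cudaSuccess;` is renamed whether or not it sits inside an `#ifdef SOME_MACRO` block. A kernel launch is left untouched, which matches the limitation noted above: that transformation needs the AST.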

ChrisKitching commented 6 years ago

Hang on emankov: ROCm-Developer-Tools/HIP#235 fixed a bunch of problems relating to the preprocessor, but the example given in the original report still fails. Hipify is still unable to properly handle preprocessor-pruned code. I'm working on it locally, but we're not quite there yet. Review the testcases I added to see exactly which cases got fixed.

But we're closer :smile:

emankov commented 6 years ago

the example given in the original report

Which of them?

ChrisKitching commented 6 years ago

Which of them?

Both of the example cases from the original report still fail. As I explained in my pull request and commit messages, my work fixes handling of unconditional macros.

You can see examples of the sorts of situations that ROCm-Developer-Tools/HIP#235 fixed here: https://github.com/ROCm-Developer-Tools/HIP/blob/094b2b9b0503c1e2935863a1d596d1045b71e7e4/tests/hipify-clang/axpy.cu#L6-L12

https://github.com/ROCm-Developer-Tools/HIP/blob/094b2b9b0503c1e2935863a1d596d1045b71e7e4/tests/hipify-clang/axpy.cu#L43-L56

emankov commented 6 years ago

OK, I misunderstood, thanks.

gargrahul commented 5 years ago

@emankov I am closing this issue. Please reopen it in case we still need to address something.

emankov commented 5 years ago

For now, hipify handles all preprocessor interactions except false conditionals correctly.

ToDo: handle false conditionals as well.

Example:

__global__ void axpy_kernel(float a, float* x, float* y) {
    y[threadIdx.x] = a * x[threadIdx.x];
}

void axpy(float a, float* x, float* y) {
#ifdef SOME_MACRO
    axpy_kernel<<<1, 4>>> (a, y, x);
#else
    axpy_kernel<<<1, 4>>> (a, x, y);
#endif
}

Current hipification:

#include <hip/hip_runtime.h>
__global__ void axpy_kernel(float a, float* x, float* y) {
    y[threadIdx.x] = a * x[threadIdx.x];
}

void axpy(float a, float* x, float* y) {
#ifdef SOME_MACRO
    axpy_kernel<<<1, 4>>> (a, y, x);
#else
    hipLaunchKernelGGL(axpy_kernel, dim3(1), dim3(4), 0, 0, a, x, y);
#endif
}

The needed hipification:

#include <hip/hip_runtime.h>
__global__ void axpy_kernel(float a, float* x, float* y) {
    y[threadIdx.x] = a * x[threadIdx.x];
}

void axpy(float a, float* x, float* y) {
#ifdef SOME_MACRO
    hipLaunchKernelGGL(axpy_kernel, dim3(1), dim3(4), 0, 0, a, y, x);
#else
    hipLaunchKernelGGL(axpy_kernel, dim3(1), dim3(4), 0, 0, a, x, y);
#endif
}

emankov commented 5 years ago

Found a way to handle all of the preprocessor's conditional blocks, but it needs changes in clang. Once the changes have been applied to trunk clang, this issue will be eliminated, but the fix will only be available in the upcoming LLVM release. I hope the needed changes will be integrated into LLVM 10.0. Unfortunately, I didn't find a way to fix the issue without touching clang; I'm going to send the clang changes for review next week, after I finish testing.

emankov commented 5 years ago