ROCm / HIP

HIP: C++ Heterogeneous-Compute Interface for Portability
https://rocmdocs.amd.com/projects/HIP/
MIT License

`__half2` implementation does not match CUDA one, resulting in a bug in `__half2float` calls #3290

Closed fxmarty closed 6 months ago

fxmarty commented 1 year ago

Hi,

I have a kernel where __half2float behaves differently with HIP vs CUDA.

Reproduce with

Use this CUDA kernel:

#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <stdio.h>

#define TPB 32  // threads per block

__global__ void distanceKernel(float* d_out) {
    const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;

    printf("Thread id: %d\n", thread_id);

    float scale_back;
    uint32_t t;
    uint32_t t2;

    float scale = 0.035;

    printf("[THREAD %d] scale %f\n", (int)threadIdx.x, scale);
    half2 scale_half2 = __float2half2_rn(scale);

    memcpy(&t, &scale_half2, sizeof(half2));
    printf("[THREAD %d] scale after memcpy: %lu\n", (int)threadIdx.x, (unsigned long)t);

    scale_back = __half2float(scale_half2.x);
    memcpy(&t2, &scale_back, sizeof(float));

    printf("[THREAD %d] int repr scale_back: %lu\n", (int)threadIdx.x, (unsigned long)t2);
    printf("[THREAD %d] scale_back: %f\n", (int)threadIdx.x, scale_back);

    d_out[thread_id] = scale_back;
}

int main() {
    float* d_out = NULL;
    float d_out_cpu[TPB];
    cudaMalloc(&d_out, TPB * sizeof(float));

    printf("Hey\n");
    distanceKernel<<<1, TPB>>>(d_out);

    cudaMemcpy(d_out_cpu, d_out, TPB * sizeof(float), cudaMemcpyDeviceToHost);

    printf("CPU result:\n");
    for (int i = 0; i < TPB; i++) {
        printf("    %f, \n", d_out_cpu[i]);
    }
    printf("\n");

    cudaFree(d_out);
    return 0;
}

Running this kernel on an NVIDIA GPU, we get the expected output:

Hey
Thread id: 0
Thread id: 1
Thread id: 2
Thread id: 3
Thread id: 4
Thread id: 5
Thread id: 6
Thread id: 7
Thread id: 8
Thread id: 9
Thread id: 10
Thread id: 11
Thread id: 12
Thread id: 13
Thread id: 14
Thread id: 15
Thread id: 16
Thread id: 17
Thread id: 18
Thread id: 19
Thread id: 20
Thread id: 21
Thread id: 22
Thread id: 23
Thread id: 24
Thread id: 25
Thread id: 26
Thread id: 27
Thread id: 28
Thread id: 29
Thread id: 30
Thread id: 31
[THREAD 0] scale 0.035000
[THREAD 1] scale 0.035000
[THREAD 2] scale 0.035000
[THREAD 3] scale 0.035000
[THREAD 4] scale 0.035000
[THREAD 5] scale 0.035000
[THREAD 6] scale 0.035000
[THREAD 7] scale 0.035000
[THREAD 8] scale 0.035000
[THREAD 9] scale 0.035000
[THREAD 10] scale 0.035000
[THREAD 11] scale 0.035000
[THREAD 12] scale 0.035000
[THREAD 13] scale 0.035000
[THREAD 14] scale 0.035000
[THREAD 15] scale 0.035000
[THREAD 16] scale 0.035000
[THREAD 17] scale 0.035000
[THREAD 18] scale 0.035000
[THREAD 19] scale 0.035000
[THREAD 20] scale 0.035000
[THREAD 21] scale 0.035000
[THREAD 22] scale 0.035000
[THREAD 23] scale 0.035000
[THREAD 24] scale 0.035000
[THREAD 25] scale 0.035000
[THREAD 26] scale 0.035000
[THREAD 27] scale 0.035000
[THREAD 28] scale 0.035000
[THREAD 29] scale 0.035000
[THREAD 30] scale 0.035000
[THREAD 31] scale 0.035000
[THREAD 0] scale after memcpy: 679159931
[THREAD 1] scale after memcpy: 679159931
[THREAD 2] scale after memcpy: 679159931
[THREAD 3] scale after memcpy: 679159931
[THREAD 4] scale after memcpy: 679159931
[THREAD 5] scale after memcpy: 679159931
[THREAD 6] scale after memcpy: 679159931
[THREAD 7] scale after memcpy: 679159931
[THREAD 8] scale after memcpy: 679159931
[THREAD 9] scale after memcpy: 679159931
[THREAD 10] scale after memcpy: 679159931
[THREAD 11] scale after memcpy: 679159931
[THREAD 12] scale after memcpy: 679159931
[THREAD 13] scale after memcpy: 679159931
[THREAD 14] scale after memcpy: 679159931
[THREAD 15] scale after memcpy: 679159931
[THREAD 16] scale after memcpy: 679159931
[THREAD 17] scale after memcpy: 679159931
[THREAD 18] scale after memcpy: 679159931
[THREAD 19] scale after memcpy: 679159931
[THREAD 20] scale after memcpy: 679159931
[THREAD 21] scale after memcpy: 679159931
[THREAD 22] scale after memcpy: 679159931
[THREAD 23] scale after memcpy: 679159931
[THREAD 24] scale after memcpy: 679159931
[THREAD 25] scale after memcpy: 679159931
[THREAD 26] scale after memcpy: 679159931
[THREAD 27] scale after memcpy: 679159931
[THREAD 28] scale after memcpy: 679159931
[THREAD 29] scale after memcpy: 679159931
[THREAD 30] scale after memcpy: 679159931
[THREAD 31] scale after memcpy: 679159931
[THREAD 0] int repr scale_back: 1024417792
[THREAD 1] int repr scale_back: 1024417792
[THREAD 2] int repr scale_back: 1024417792
[THREAD 3] int repr scale_back: 1024417792
[THREAD 4] int repr scale_back: 1024417792
[THREAD 5] int repr scale_back: 1024417792
[THREAD 6] int repr scale_back: 1024417792
[THREAD 7] int repr scale_back: 1024417792
[THREAD 8] int repr scale_back: 1024417792
[THREAD 9] int repr scale_back: 1024417792
[THREAD 10] int repr scale_back: 1024417792
[THREAD 11] int repr scale_back: 1024417792
[THREAD 12] int repr scale_back: 1024417792
[THREAD 13] int repr scale_back: 1024417792
[THREAD 14] int repr scale_back: 1024417792
[THREAD 15] int repr scale_back: 1024417792
[THREAD 16] int repr scale_back: 1024417792
[THREAD 17] int repr scale_back: 1024417792
[THREAD 18] int repr scale_back: 1024417792
[THREAD 19] int repr scale_back: 1024417792
[THREAD 20] int repr scale_back: 1024417792
[THREAD 21] int repr scale_back: 1024417792
[THREAD 22] int repr scale_back: 1024417792
[THREAD 23] int repr scale_back: 1024417792
[THREAD 24] int repr scale_back: 1024417792
[THREAD 25] int repr scale_back: 1024417792
[THREAD 26] int repr scale_back: 1024417792
[THREAD 27] int repr scale_back: 1024417792
[THREAD 28] int repr scale_back: 1024417792
[THREAD 29] int repr scale_back: 1024417792
[THREAD 30] int repr scale_back: 1024417792
[THREAD 31] int repr scale_back: 1024417792
[THREAD 0] scale_back: 0.035004
[THREAD 1] scale_back: 0.035004
[THREAD 2] scale_back: 0.035004
[THREAD 3] scale_back: 0.035004
[THREAD 4] scale_back: 0.035004
[THREAD 5] scale_back: 0.035004
[THREAD 6] scale_back: 0.035004
[THREAD 7] scale_back: 0.035004
[THREAD 8] scale_back: 0.035004
[THREAD 9] scale_back: 0.035004
[THREAD 10] scale_back: 0.035004
[THREAD 11] scale_back: 0.035004
[THREAD 12] scale_back: 0.035004
[THREAD 13] scale_back: 0.035004
[THREAD 14] scale_back: 0.035004
[THREAD 15] scale_back: 0.035004
[THREAD 16] scale_back: 0.035004
[THREAD 17] scale_back: 0.035004
[THREAD 18] scale_back: 0.035004
[THREAD 19] scale_back: 0.035004
[THREAD 20] scale_back: 0.035004
[THREAD 21] scale_back: 0.035004
[THREAD 22] scale_back: 0.035004
[THREAD 23] scale_back: 0.035004
[THREAD 24] scale_back: 0.035004
[THREAD 25] scale_back: 0.035004
[THREAD 26] scale_back: 0.035004
[THREAD 27] scale_back: 0.035004
[THREAD 28] scale_back: 0.035004
[THREAD 29] scale_back: 0.035004
[THREAD 30] scale_back: 0.035004
[THREAD 31] scale_back: 0.035004
CPU result:
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 
    0.035004, 

Running then:

hipify-perl kernel_half.cu -o=kernel_half.hip
hipcc kernel_half.hip
./a.out

on an AMD machine (here MI250), we get:

Hey
Thread id: 0
Thread id: 1
Thread id: 2
Thread id: 3
Thread id: 4
Thread id: 5
Thread id: 6
Thread id: 7
Thread id: 8
Thread id: 9
Thread id: 10
Thread id: 11
Thread id: 12
Thread id: 13
Thread id: 14
Thread id: 15
Thread id: 16
Thread id: 17
Thread id: 18
Thread id: 19
Thread id: 20
Thread id: 21
Thread id: 22
Thread id: 23
Thread id: 24
Thread id: 25
Thread id: 26
Thread id: 27
Thread id: 28
Thread id: 29
Thread id: 30
Thread id: 31
[THREAD 0] scale 0.035000
[THREAD 1] scale 0.035000
[THREAD 2] scale 0.035000
[THREAD 3] scale 0.035000
[THREAD 4] scale 0.035000
[THREAD 5] scale 0.035000
[THREAD 6] scale 0.035000
[THREAD 7] scale 0.035000
[THREAD 8] scale 0.035000
[THREAD 9] scale 0.035000
[THREAD 10] scale 0.035000
[THREAD 11] scale 0.035000
[THREAD 12] scale 0.035000
[THREAD 13] scale 0.035000
[THREAD 14] scale 0.035000
[THREAD 15] scale 0.035000
[THREAD 16] scale 0.035000
[THREAD 17] scale 0.035000
[THREAD 18] scale 0.035000
[THREAD 19] scale 0.035000
[THREAD 20] scale 0.035000
[THREAD 21] scale 0.035000
[THREAD 22] scale 0.035000
[THREAD 23] scale 0.035000
[THREAD 24] scale 0.035000
[THREAD 25] scale 0.035000
[THREAD 26] scale 0.035000
[THREAD 27] scale 0.035000
[THREAD 28] scale 0.035000
[THREAD 29] scale 0.035000
[THREAD 30] scale 0.035000
[THREAD 31] scale 0.035000
[THREAD 0] scale after memcpy: 679159931
[THREAD 1] scale after memcpy: 679159931
[THREAD 2] scale after memcpy: 679159931
[THREAD 3] scale after memcpy: 679159931
[THREAD 4] scale after memcpy: 679159931
[THREAD 5] scale after memcpy: 679159931
[THREAD 6] scale after memcpy: 679159931
[THREAD 7] scale after memcpy: 679159931
[THREAD 8] scale after memcpy: 679159931
[THREAD 9] scale after memcpy: 679159931
[THREAD 10] scale after memcpy: 679159931
[THREAD 11] scale after memcpy: 679159931
[THREAD 12] scale after memcpy: 679159931
[THREAD 13] scale after memcpy: 679159931
[THREAD 14] scale after memcpy: 679159931
[THREAD 15] scale after memcpy: 679159931
[THREAD 16] scale after memcpy: 679159931
[THREAD 17] scale after memcpy: 679159931
[THREAD 18] scale after memcpy: 679159931
[THREAD 19] scale after memcpy: 679159931
[THREAD 20] scale after memcpy: 679159931
[THREAD 21] scale after memcpy: 679159931
[THREAD 22] scale after memcpy: 679159931
[THREAD 23] scale after memcpy: 679159931
[THREAD 24] scale after memcpy: 679159931
[THREAD 25] scale after memcpy: 679159931
[THREAD 26] scale after memcpy: 679159931
[THREAD 27] scale after memcpy: 679159931
[THREAD 28] scale after memcpy: 679159931
[THREAD 29] scale after memcpy: 679159931
[THREAD 30] scale after memcpy: 679159931
[THREAD 31] scale after memcpy: 679159931
[THREAD 0] int repr scale_back: 1176625152
[THREAD 1] int repr scale_back: 1176625152
[THREAD 2] int repr scale_back: 1176625152
[THREAD 3] int repr scale_back: 1176625152
[THREAD 4] int repr scale_back: 1176625152
[THREAD 5] int repr scale_back: 1176625152
[THREAD 6] int repr scale_back: 1176625152
[THREAD 7] int repr scale_back: 1176625152
[THREAD 8] int repr scale_back: 1176625152
[THREAD 9] int repr scale_back: 1176625152
[THREAD 10] int repr scale_back: 1176625152
[THREAD 11] int repr scale_back: 1176625152
[THREAD 12] int repr scale_back: 1176625152
[THREAD 13] int repr scale_back: 1176625152
[THREAD 14] int repr scale_back: 1176625152
[THREAD 15] int repr scale_back: 1176625152
[THREAD 16] int repr scale_back: 1176625152
[THREAD 17] int repr scale_back: 1176625152
[THREAD 18] int repr scale_back: 1176625152
[THREAD 19] int repr scale_back: 1176625152
[THREAD 20] int repr scale_back: 1176625152
[THREAD 21] int repr scale_back: 1176625152
[THREAD 22] int repr scale_back: 1176625152
[THREAD 23] int repr scale_back: 1176625152
[THREAD 24] int repr scale_back: 1176625152
[THREAD 25] int repr scale_back: 1176625152
[THREAD 26] int repr scale_back: 1176625152
[THREAD 27] int repr scale_back: 1176625152
[THREAD 28] int repr scale_back: 1176625152
[THREAD 29] int repr scale_back: 1176625152
[THREAD 30] int repr scale_back: 1176625152
[THREAD 31] int repr scale_back: 1176625152
[THREAD 0] scale_back: 10360.000000
[THREAD 1] scale_back: 10360.000000
[THREAD 2] scale_back: 10360.000000
[THREAD 3] scale_back: 10360.000000
[THREAD 4] scale_back: 10360.000000
[THREAD 5] scale_back: 10360.000000
[THREAD 6] scale_back: 10360.000000
[THREAD 7] scale_back: 10360.000000
[THREAD 8] scale_back: 10360.000000
[THREAD 9] scale_back: 10360.000000
[THREAD 10] scale_back: 10360.000000
[THREAD 11] scale_back: 10360.000000
[THREAD 12] scale_back: 10360.000000
[THREAD 13] scale_back: 10360.000000
[THREAD 14] scale_back: 10360.000000
[THREAD 15] scale_back: 10360.000000
[THREAD 16] scale_back: 10360.000000
[THREAD 17] scale_back: 10360.000000
[THREAD 18] scale_back: 10360.000000
[THREAD 19] scale_back: 10360.000000
[THREAD 20] scale_back: 10360.000000
[THREAD 21] scale_back: 10360.000000
[THREAD 22] scale_back: 10360.000000
[THREAD 23] scale_back: 10360.000000
[THREAD 24] scale_back: 10360.000000
[THREAD 25] scale_back: 10360.000000
[THREAD 26] scale_back: 10360.000000
[THREAD 27] scale_back: 10360.000000
[THREAD 28] scale_back: 10360.000000
[THREAD 29] scale_back: 10360.000000
[THREAD 30] scale_back: 10360.000000
[THREAD 31] scale_back: 10360.000000
CPU result:
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 
    10360.000000, 

which is obviously wrong.

To me, the issue stems from the definition of the `__half2` struct in ROCm's `include/hip/amd_detail/amd_hip_fp16.h`:

        // BEGIN STRUCT __HALF2
        struct __half2 {
        public:
            union {
                static_assert(
                    sizeof(_Float16_2) == sizeof(unsigned short[2]), "");

                _Float16_2 data;
                struct {
                    unsigned short x;
                    unsigned short y;
                };
            };
// truncated

Compare this with the definition in CUDA (though the purpose of `__CUDA_ALIGN__(4)` is unclear to me):

struct __CUDA_ALIGN__(4) __half2 {
    __half x;
    __half y;
// truncated

Replacing the call

scale_back = __half2float(scale_half2.x);

by `scale_back = (float)scale_half2.data.x;` or by `scale_back = __low2float(scale_half2);` solves the issue. But I believe this is a bug, given that the CUDA kernel and the HIP kernel behave differently and no error is raised. Maybe the hipifier should handle this case?

Related https://github.com/ROCm-Developer-Tools/HIP/issues/3280

cc @ardfork @cjatin

xinyi-li7 commented 1 year ago

Same here. I'd like to know if there are any solutions.

cjatin commented 1 year ago

Thanks for reporting this. Will look into it.

cjatin commented 1 year ago

You are right: the issue is due to half2 using `unsigned short x, y`; it should be `half x, y`.

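At the moment HIP casts the unsigned short to half by value (an integer-to-float conversion, so the raw bits 10363 become roughly 10360.0) instead of reinterpreting the bits, which is causing this issue.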

I will raise a PR to fix it, which might take some time to reach GitHub. In the meantime, you can work around it by manually casting to `__half` before calling `__half2float`, something like `__half2float(*reinterpret_cast<__half*>(&scale_half2.x));`

emankov commented 1 year ago

Right. For now (I'm not sure this will be fixed soon, because of backward compatibility), `reinterpret_cast<__half*>` is the only solution. It is already supported by hipify-clang (#801), but not by hipify-perl.

ppanchad-amd commented 6 months ago

@fxmarty Can you please test with the latest ROCm 6.1.0 (HIP 6.1)? If resolved, please close the ticket. Thanks!

fxmarty commented 6 months ago

@ppanchad-amd Just tested; this is fixed, thank you.