alpaka-group / alpaka

Abstraction Library for Parallel Kernel Acceleration :llama:
https://alpaka.readthedocs.io
Mozilla Public License 2.0

portable device-side `abort()` or `throw` ? #2258

Open fwyzard opened 2 months ago

fwyzard commented 2 months ago

In a few places in the CMS code we would like to signal an error from the device side to the host code.

We are looking for something similar to a device-side `assert(false)`, but with a more meaningful error message and always available (`ALPAKA_ASSERT_ACC` may not be implemented on some back-ends).

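For example, in one case we have something like:

```cpp
#include <cstdio>     // printf, used by the device-side branches
#include <stdexcept>  // std::runtime_error, used by the host-side branch

#if defined(__CUDACC__) && defined(__CUDA_ARCH__)
// CUDA device code: print the message, then abort the kernel with a trap
#define CMS_DEVICE_THROW(MSG)           \
  do {                                  \
    printf("%s\n", (MSG));              \
    __trap();                           \
  } while (0)
#elif defined(__HIPCC__) && defined(__HIP_DEVICE_COMPILE__)
// HIP device code: print the message, then abort the kernel
#define CMS_DEVICE_THROW(MSG)           \
  do {                                  \
    printf("%s\n", (MSG));              \
    abort();                            \
  } while (0)
#else
// host code (including CPU back-ends): throw a regular C++ exception
#define CMS_DEVICE_THROW(MSG)           \
  do {                                  \
    throw std::runtime_error(MSG);      \
  } while (0)
#endif
```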
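On a CPU back-end only the `#else` branch is compiled, so the error reaches the host as an ordinary `std::runtime_error` that the caller can catch and report. A minimal host-only sketch, assuming the macro reduced to that branch; `CheckedScaleKernel` and `runAndReport` are hypothetical names for illustration, not actual CMS or alpaka code:

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Host-only reduction of CMS_DEVICE_THROW: on CPU back-ends the error
// becomes an ordinary C++ exception.
#define CMS_DEVICE_THROW(MSG)      \
  do {                             \
    throw std::runtime_error(MSG); \
  } while (0)

// Hypothetical kernel-like functor: scales the first n elements of data,
// signalling an error if n exceeds the buffer size.
struct CheckedScaleKernel {
  void operator()(std::vector<float>& data, std::size_t n, float factor) const {
    if (n > data.size())
      CMS_DEVICE_THROW("CheckedScaleKernel: n exceeds the buffer size");
    else
      for (std::size_t i = 0; i < n; ++i)
        data[i] *= factor;
  }
};

// Hypothetical host-side wrapper: runs the kernel and returns the error
// message, or an empty string if the kernel completed normally.
std::string runAndReport(std::vector<float>& data, std::size_t n, float factor) {
  try {
    CheckedScaleKernel{}(data, n, factor);
  } catch (std::runtime_error const& e) {
    return e.what();
  }
  return {};
}
```

Writing `if (...) CMS_DEVICE_THROW(...); else ...` only compiles because the macro expands to a single `do { ... } while (0)` statement; with a bare `{ ... }` block, the trailing `;` would terminate the `if` and leave the `else` without a matching `if`.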

Does Alpaka already provide something like this?

If not, would it be interesting to have it in Alpaka?

SimeonEhrig commented 2 months ago

I talked with @psychocoderHPC. I made a mistake: it was not alpaka, it was PIConGPU (PMacc) where I saw it, and it uses native traps.

psychocoderHPC commented 2 months ago

> If not, would it be interesting to have it in Alpaka?

IMO this is interesting. A cool feature would be the ability to handle such throws to avoid an application crash, but as far as I know this is not possible. As @SimeonEhrig said, we have this in PMacc too, so if alpaka provided it we could remove the vendor-specific code from PMacc: https://github.com/ComputationalRadiationPhysics/picongpu/blob/d18b6822e3074e55b3eab6d7c4c9024564efea33/include/pmacc/verify.hpp#L64-L72

@fwyzard For HIP we used the Clang intrinsic `__builtin_trap()`; I am not sure whether it is equivalent to `abort()`. Be careful: if I remember correctly, `abort()` could introduce additional register usage, but maybe `__builtin_trap()` increases the register footprint too.
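A sketch of a portable macro with the HIP branch switched to `__builtin_trap()`, as described for PMacc. This is an assumption for illustration: `PORTABLE_DEVICE_THROW` and `checkedDivide` are hypothetical names (not existing alpaka or PMacc API), and whether `__builtin_trap()` really uses fewer registers than `abort()` would have to be checked in the generated ISA.

```cpp
#include <cassert>
#include <cstdio>
#include <stdexcept>

#if defined(__CUDACC__) && defined(__CUDA_ARCH__)
// CUDA device code: print the message, then trap
#define PORTABLE_DEVICE_THROW(MSG) \
  do {                             \
    printf("%s\n", (MSG));         \
    __trap();                      \
  } while (0)
#elif defined(__HIPCC__) && defined(__HIP_DEVICE_COMPILE__)
// HIP device code: Clang intrinsic instead of abort()
#define PORTABLE_DEVICE_THROW(MSG) \
  do {                             \
    printf("%s\n", (MSG));         \
    __builtin_trap();              \
  } while (0)
#else
// host code and CPU back-ends: regular C++ exception
#define PORTABLE_DEVICE_THROW(MSG)   \
  do {                               \
    throw std::runtime_error(MSG);   \
  } while (0)
#endif

// Hypothetical checked operation, compiled here for the host only.
// In real device code it would also need __device__ (ALPAKA_FN_ACC in alpaka).
inline int checkedDivide(int num, int den) {
  if (den == 0)
    PORTABLE_DEVICE_THROW("checkedDivide: division by zero");
  return num / den;
}
```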

fwyzard commented 2 months ago

With CUDA 12.3 and `nvcc -O2 -g -arch=sm_75` (link):

```cuda
__global__ void test() {
    __trap();
}
```

compiles to

```
{
    trap;
    ret;
}
```

while

```cuda
__global__ void test() {
    assert(0);
}
```

compiles to

```
{
    mov.u64         %rd1, $str;
    cvta.global.u64 %rd2, %rd1;
    mov.u64         %rd3, $str$1;
    cvta.global.u64 %rd4, %rd3;
    mov.u64         %rd5, __unnamed_1;
    cvta.global.u64 %rd6, %rd5;
    { // callseq 0, 0
        st.param.b64    [param0+0], %rd2;
        st.param.b64    [param1+0], %rd4;
        st.param.b32    [param2+0], 4;
        st.param.b64    [param3+0], %rd6;
        st.param.b64    [param4+0], 1;
        call.uni __assertfail, (param0, param1, param2, param3, param4);
    } // callseq 0
    ret;
}
```

CUDA does not have `abort()`; there is `cooperative_groups::details::abort()`, which expands either to `__trap()` or to `__assert(0)` (if extra debug information is enabled).


With ROCm 5.7 and `hipcc -O2 -g --offload-arch=gfx90a` (link):

```cpp
__global__ void test() {
    __trap();
}
```

compiles to

```asm
test1():                              ; @test1()
        s_trap 2
```

while

```cpp
__global__ void test() {
    assert(0);
}
```

compiles to

```asm
test2():                              ; @test2()
        s_add_u32 flat_scratch_lo, s6, s9
        s_addc_u32 flat_scratch_hi, s7, 0
        s_add_u32 s0, s0, s9
        s_addc_u32 s1, s1, 0
        s_mov_b64 s[8:9], s[4:5]
        s_mov_b32 s32, 0
        s_getpc_b64 s[6:7]
        s_add_u32 s6, s6, __assert_fail@rel32@lo+4
        s_addc_u32 s7, s7, __assert_fail@rel32@hi+12
        s_swappc_b64 s[30:31], s[6:7]
```

plus a very long definition of `__assert_fail`.

mehmetyusufoglu commented 1 month ago

How can these new macros be tested? As far as I understand, something similar to `CMS_DEVICE_THROW` would be called inside the kernel, and the error would then be caught by `rtCheckLastError` (which calls `::cudaGetLastError()`, etc.)? I mean catching the throw.
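On CPU back-ends the macro throws a regular C++ exception, so a unit test can simply check that the exception reaches the host with the expected message. On the CUDA and HIP back-ends a trap aborts the kernel; the error is typically reported asynchronously by the next synchronizing call (for example `cudaDeviceSynchronize()` or `cudaStreamSynchronize()`) and is a sticky error that leaves the device context unusable, so such tests usually need to run in a separate process. A minimal sketch for the CPU case only, where `expectDeviceThrow` is a hypothetical helper name:

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Host branch of the macro, as compiled for CPU back-ends.
#define CMS_DEVICE_THROW(MSG)      \
  do {                             \
    throw std::runtime_error(MSG); \
  } while (0)

// Hypothetical test helper: returns true if running `body` throws a
// std::runtime_error whose message equals `expected`, false if it throws
// one with a different message or does not throw at all.
// Exceptions of other types propagate to the caller.
template <typename Body>
bool expectDeviceThrow(Body&& body, std::string const& expected) {
  try {
    body();
  } catch (std::runtime_error const& e) {
    return e.what() == expected;
  }
  return false;  // no exception: the error was not signalled
}
```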