turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

Q4 cache CUDA API calls fail to compile on ROCm HIP #361

Closed: kzha0 closed this issue 4 months ago

kzha0 commented 4 months ago

The new Q4 cache feature of ExLlamaV2 introduced calls to several CUDA Device API functions that are currently unsupported on HIP, specifically:

- __shfl_down_sync
- __shfl_xor_sync
- __hmax2

More info on which CUDA Device API functions are currently supported by HIP: https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Device_API_supported_by_HIP.html

I did some digging and found that the _sync versions of those functions can be emulated in HIP by substituting each one with its non-sync counterpart and dropping the mask argument (e.g. __shfl_xor_sync becomes just __shfl_xor).
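To illustrate what the non-sync substitution has to preserve, here is a host-side C++ model (plain C++, not HIP or CUDA code) of a warp-wide XOR-butterfly max reduction, a common use of __shfl_xor_sync. The names kWarpSize, Warp, shfl_xor, shfl_xor_sync and warp_max are hypothetical. The model assumes a 32-lane warp that executes in lockstep with every lane participating. It does not model HIP's 64-lane wavefronts or the optional width argument, and it is not taken from cache.cu.

```cpp
#include <algorithm>
#include <array>
#include <cassert>

// Hypothetical host-side model of one 32-lane warp (plain C++, not HIP/CUDA).
// Every "lane" is updated in the same step, i.e. the warp runs in lockstep.
constexpr int kWarpSize = 32;
using Warp = std::array<float, kWarpSize>;

// Models a full-width __shfl_xor(var, laneMask): lane i receives the value
// held by lane (i ^ laneMask). Sub-warp `width` segments are not modelled.
Warp shfl_xor(const Warp& var, int laneMask)
{
    assert(laneMask > 0 && laneMask < kWarpSize);
    Warp out{};
    for (int lane = 0; lane < kWarpSize; ++lane)
        out[lane] = var[lane ^ laneMask];
    return out;
}

// Mirrors the macro substitution from the issue: the mask is accepted and
// ignored. That is harmless here because all lanes always participate.
Warp shfl_xor_sync(unsigned /*mask*/, const Warp& var, int laneMask)
{
    return shfl_xor(var, laneMask);
}

// XOR-butterfly max reduction: after log2(32) = 5 exchange steps,
// every lane holds the maximum over the whole warp.
Warp warp_max(Warp v)
{
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
        Warp other = shfl_xor_sync(0xffffffffu, v, offset);
        for (int lane = 0; lane < kWarpSize; ++lane)
            v[lane] = std::max(v[lane], other[lane]);
    }
    return v;
}
```

For example, filling the lanes with (7 * lane) % 32, which is a permutation of 0..31, and calling warp_max leaves 31 in every lane. Dropping the mask is only safe under the same condition the model assumes: every lane of the warp reaches the shuffle together.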

As for the missing __hmax2: since the HIP API supports the __half2 type and the scalar __hmax operation, a wrapper function can apply __hmax to the .x and .y halves of the two __half2 arguments separately and combine the results into a pseudo-__hmax2.

I hacked together a fix and dropped it in at the top of exllamav2/exllamav2_ext/cuda/cache.cu, which got ExLlamaV2 0.0.14 to compile successfully.

...
// Temporary wrappers for CUDA device functions not supported in HIP.
// __hmax2 and the *_sync shuffles are functions, not macros, so an #ifndef
// test can never detect them; guard on the AMD HIP platform macro instead so
// the CUDA build keeps using the real intrinsics.
// (Assumes the HIP version in use does not already provide these.)
#if defined(__HIP_PLATFORM_AMD__)

// `__hmax2` substitute: apply the scalar `__hmax` to the low (.x) and high (.y) halves
__device__ __forceinline__ half2 __hmax2(half2 a, half2 b)
{
    half2 result;
    result.x = __hmax(a.x, b.x);
    result.y = __hmax(a.y, b.y);
    return result;
}

// Substitute the *_sync shuffles with HIP's non-sync __shfl_down / __shfl_xor,
// dropping the mask argument. Variadic so calls with or without the optional
// width argument both expand correctly.
#define __shfl_down_sync(mask, ...) __shfl_down(__VA_ARGS__)
#define __shfl_xor_sync(mask, ...) __shfl_xor(__VA_ARGS__)

#endif
...

I'm not sure where in the file structure this kind of wrapper/utility should be placed, though.

turboderp commented 4 months ago

Thanks for pointing this out. Compilation should be fixed with the latest commit.