NVIDIA / cccl

CUDA Core Compute Libraries
https://nvidia.github.io/cccl/

[FEA]: Support fancy iterators in cuda.parallel #2479

Open gevtushenko opened 1 month ago

gevtushenko commented 1 month ago

Is this a duplicate?

Area

General CCCL

Is your feature request related to a problem? Please describe.

Compared to existing solutions in PyTorch and CuPy, one of the distinguishing features of cuda.parallel is flexibility. Part of that flexibility comes from support for user-defined data types and operators. But compared to the CUDA C++ solution, the cuda.parallel API is still limited: we are still missing fancy iterators and cache-modified iterators in cuda.parallel.
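To make the gap concrete, here is a sketch using hypothetical names (count and map are borrowed from the proof-of-concept later in this thread, not current cuda.parallel API):

```python
import numpy as np
import numba.cuda

# Today the input to an algorithm must be a materialized device array:
d_in = numba.cuda.to_device(np.arange(10, dtype=np.int32))

# With fancy iterators, the same sequence could be described implicitly,
# with no device allocation (hypothetical API):
#   d_in = count(0)                        # 0, 1, 2, ...
#   d_in = map(count(0), lambda x: 2 * x)  # 0, 2, 4, ...
```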

Describe the solution you'd like

Given that fancy iterator support might require rather invasive changes to cuda.parallel and the CCCL/c library, we should design fancy iterators before introducing more algorithms to cuda.parallel.

Describe alternatives you've considered

No response

Additional context

No response

rwgk commented 1 month ago

Keeping track of a link I got from Georgii: https://github.com/NVIDIA/cccl/pull/2335

rwgk commented 1 month ago

Tracking for easy future reference:

https://numba.pydata.org/numba-doc/dev/user/jitclass.html (in particular the Limitations section)
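For context, a minimal jitclass sketch (using numba.experimental.jitclass; mapping iterator state onto a typed class is one natural approach here, and the Limitations section is what constrains it):

```python
from numba import int32
from numba.experimental import jitclass

# Minimal sketch: iterator-like state as a jitclass with typed fields.
@jitclass([("count", int32)])
class CountingState:
    def __init__(self, start):
        self.count = start

    def advance(self, diff):
        self.count += diff

    def dereference(self):
        return self.count

s = CountingState(42)
s.advance(3)
assert s.dereference() == 45
```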

rwgk commented 1 month ago

I'm currently doing work on this branch:

https://github.com/rwgk/cccl/tree/python_random_access_iterators

Last commit https://github.com/rwgk/cccl/commit/d1c4816f8f3391c97e6fd32a89d45785615f6ea1 — Use TransformRAI to implement constant, counting, arbitrary RAIs.
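A minimal sketch of the idea behind that commit (illustrative names only, not the branch API): a transform RAI maps an index to a value, so constant and counting iterators fall out as special cases.

```python
# Illustrative only: the real branch wraps the op for device codegen;
# here the op itself stands in for the iterator.
def transform_rai(op):
    return op

def constant_rai(value):
    return transform_rai(lambda index: value)

def counting_rai(start):
    return transform_rai(lambda index: start + index)

assert constant_rai(42)(17) == 42
assert counting_rai(42)(3) == 45
```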

Current thinking:

rwgk commented 1 month ago

Some points for the upcoming meeting, to capture the state of my understanding:

`scalar_with_certain_dtype = RAI_object[random_access_array_index]`

https://github.com/NVIDIA/cccl/blob/e5229f2c7509ced5a830be4ae884d9e1639e8951/c/parallel/src/reduce.cu#L395-L417

That code assembles C++ source that is compiled with nvrtc a little further down.
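Roughly, the pattern is source stitching: declare the iterator's extern "C" entry points, then wrap them in a struct the kernel can use as an iterator. A simplified Python sketch (illustrative names; the kernel.cpp PoC further down in this thread does the real version):

```python
# Simplified sketch of the C++ source assembled before nvrtc compilation.
prefix = "count_int32"
iterator_size, iterator_alignment = 4, 4

cpp_source = (
    f'extern "C" __device__ int {prefix}_dereference(const char *state);\n'
    f'extern "C" __device__ void {prefix}_advance(char *state, int distance);\n'
    f"struct __align__({iterator_alignment}) iterator_t {{\n"
    f"  __device__ int operator*() const {{ return {prefix}_dereference(data); }}\n"
    f"  __device__ iterator_t& operator+=(int diff) {{\n"
    f"    {prefix}_advance(data, diff); return *this;\n"
    f"  }}\n"
    f"  char data[{iterator_size}];\n"
    f"}};\n"
)
print(cpp_source)  # this string is what gets handed to nvrtc
```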

At this stage I'm very conservative: I'm aiming for minimal changes to Georgii's code to get the desired behavior passing.


Non-goal at this stage: cache-modified C++ input iterator. — I still need to learn what exactly is needed. In a follow-on step (to the work above), I want to solve this also in a conservative fashion.

When that is done I want to take a step back to review:

rwgk commented 1 month ago

Tracking progress:

I just created this Draft PR: https://github.com/NVIDIA/cccl/pull/2595, currently @ commit 5ba7a0f413123cc05e6eb9f3690e8b571659c670

Copy-pasting the current PR description:


Goal: Any unary_op(distance) that can be compiled by numba can be passed as the d_in argument of reduce_into(d_in)

Current status:

```python
def other_unary_op(distance):
    permutation = (4, 2, 0, 3, 1)
    return permutation[distance % len(permutation)]

def input_unary_op(distance):
    return 2 * other_unary_op(distance)
```

The error is:

```
E       numba.core.errors.TypingError: Failed in cuda mode pipeline (step: nopython frontend)
E       Untyped global name 'other_unary_op': Cannot determine Numba type of <class 'function'>
E
E       File "tests/test_reduce.py", line 117:
E               def input_unary_op(distance):
E                   return 2 * other_unary_op(distance)
E                   ^
```

gevtushenko commented 1 month ago

Let's take a look at some of the API that we need in cuda.parallel.itertools. Below is a proof-of-concept implementation of count, repeat, map, and cache.

```python
num_items = 4
output_array = numba.cuda.device_array(num_items, dtype=np.int32)

# Returns 42 N times
r = repeat(42)
parallel_algorithm(r, num_items, output_array)
print("expect: 42 42 42 42;  get: ", " ".join([str(x) for x in output_array.copy_to_host()]))

# Returns an integer sequence starting at 42
c = count(42)
parallel_algorithm(c, num_items, output_array)
print("expect: 42 43 44 45;  get: ", " ".join([str(x) for x in output_array.copy_to_host()]))

# Multiplies 42 (coming from repeat) by 2
def mult(x):
    return x * 2

mult_42_by_2 = map(r, mult)
parallel_algorithm(mult_42_by_2, num_items, output_array)
print("expect: 84 84 84 84;  get: ", " ".join([str(x) for x in output_array.copy_to_host()]))

# Adds 10 to result of multiplication of repeat by 2
def add(x):
    return x + 10

mult_42_by_2_plus10 = map(mult_42_by_2, add)
parallel_algorithm(mult_42_by_2_plus10, num_items, output_array)
print("expect: 94 94 94 94;  get: ", " ".join([str(x) for x in output_array.copy_to_host()]))

# Same as above, but for count
mult_count_by_2 = map(c, mult)
parallel_algorithm(mult_count_by_2, num_items, output_array)
print("expect: 84 86 88 90;  get: ", " ".join([str(x) for x in output_array.copy_to_host()]))

mult_count_by_2_and_add_10 = map(mult_count_by_2, add)
parallel_algorithm(mult_count_by_2_and_add_10, num_items, output_array)
print("expect: 94 96 98 100; get:", " ".join([str(x) for x in output_array.copy_to_host()]))

# Example of how combinational iterators can wrap pointer in a generic way
input_array = numba.cuda.to_device(np.array([4, 3, 2, 1], dtype=np.int32))
ptr = pointer(input_array) # TODO this transformation should be hidden on the transform implementation side
parallel_algorithm(ptr, num_items, output_array)
print("expect:  4  3  2 1  ; get:", " ".join([str(x) for x in output_array.copy_to_host()]))

input_array = numba.cuda.to_device(np.array([4, 3, 2, 1], dtype=np.int32))
ptr = pointer(input_array) # TODO this transformation should be hidden on the transform implementation side
tptr = map(ptr, mult)
parallel_algorithm(tptr, num_items, output_array)
print("expect:  8  6  4 2  ; get:", " ".join([str(x) for x in output_array.copy_to_host()]))

# Example of caching iterator
streamed_input = cache(input_array, 'stream')
parallel_algorithm(streamed_input, num_items, output_array)
print("expect:  4  3  2 1  ; get:", " ".join([str(x) for x in output_array.copy_to_host()]))
main.py

```python
from numba.core import cgutils
from llvmlite import ir
from numba import types
from numba.core.typing import signature
from numba.core.extending import intrinsic, overload
from numba.core.errors import NumbaTypeError
import ctypes
import numba
import numba.cuda
import numpy as np


class RawPointer:
    def __init__(self, ptr):  # TODO Showcasing the case of int32, need dtype with at least primitive types, ideally any numba type
        self.val = ctypes.c_void_p(ptr)
        self.ltoirs = [numba.cuda.compile(RawPointer.pointer_advance,
                                          sig=numba.types.void(numba.types.CPointer(numba.types.uint64), numba.types.int32),
                                          output='ltoir'),
                       numba.cuda.compile(RawPointer.pointer_dereference,
                                          sig=numba.types.int32(numba.types.CPointer(numba.types.CPointer(numba.types.int32))),
                                          output='ltoir')]
        self.prefix = 'pointer'

    def pointer_advance(this, distance):
        this[0] = this[0] + numba.types.uint64(4 * distance)  # TODO Showcasing the case of int32, need dtype with at least primitive types, ideally any numba type

    def pointer_dereference(this):
        return this[0][0]

    def host_address(self):
        return ctypes.byref(self.val)

    def size(self):
        return 8  # TODO should be using numba for user-defined types support

    def alignment(self):
        return 8  # TODO should be using numba for user-defined types support


def pointer(container):
    return RawPointer(container.device_ctypes_pointer.value)


@intrinsic
def ldcs(typingctx, base):
    signature = types.int32(types.CPointer(types.int32))

    def codegen(context, builder, sig, args):
        int32 = ir.IntType(32)
        int32_ptr = int32.as_pointer()
        ldcs_type = ir.FunctionType(int32, [int32_ptr])
        ldcs = ir.InlineAsm(ldcs_type, "ld.global.cs.b32 $0, [$1];", "=r, l")
        return builder.call(ldcs, args)

    return signature, codegen


class CacheModifiedPointer:
    def __init__(self, ptr):  # TODO Showcasing the case of int32, need dtype with at least primitive types, ideally any numba type
        self.val = ctypes.c_void_p(ptr)
        self.ltoirs = [numba.cuda.compile(CacheModifiedPointer.cache_advance,
                                          sig=numba.types.void(numba.types.CPointer(numba.types.uint64), numba.types.int32),
                                          output='ltoir'),
                       numba.cuda.compile(CacheModifiedPointer.cache_dereference,
                                          sig=numba.types.int32(numba.types.CPointer(numba.types.CPointer(numba.types.int32))),
                                          output='ltoir')]
        self.prefix = 'cache'

    def cache_advance(this, distance):
        this[0] = this[0] + numba.types.uint64(4 * distance)  # TODO Showcasing the case of int32, need dtype with at least primitive types, ideally any numba type

    def cache_dereference(this):
        return ldcs(this[0])

    def host_address(self):
        return ctypes.byref(self.val)

    def size(self):
        return 8  # TODO should be using numba for user-defined types support

    def alignment(self):
        return 8  # TODO should be using numba for user-defined types support


def cache(container, modifier='stream'):
    if modifier != 'stream':
        raise NotImplementedError("Only stream modifier is supported")
    return CacheModifiedPointer(container.device_ctypes_pointer.value)


class ConstantIterator:
    def __init__(self, val):  # TODO Showcasing the case of int32, need dtype with at least primitive types, ideally any numba type
        thisty = numba.types.CPointer(numba.types.int32)
        self.val = ctypes.c_int32(val)
        self.ltoirs = [numba.cuda.compile(ConstantIterator.constant_int32_advance, sig=numba.types.void(thisty, numba.types.int32), output='ltoir'),
                       numba.cuda.compile(ConstantIterator.constant_int32_dereference, sig=numba.types.int32(thisty), output='ltoir')]
        self.prefix = 'constant_int32'

    def constant_int32_advance(this, _):
        pass

    def constant_int32_dereference(this):
        return this[0]

    def host_address(self):  # TODO should use numba instead for support of user-defined types
        return ctypes.byref(self.val)

    def size(self):
        return ctypes.sizeof(self.val)  # TODO should be using numba for user-defined types support

    def alignment(self):
        return ctypes.alignment(self.val)  # TODO should be using numba for user-defined types support


def repeat(value):
    return ConstantIterator(value)


class CountingIterator:
    def __init__(self, count):  # TODO Showcasing the case of int32, need dtype
        thisty = numba.types.CPointer(numba.types.int32)
        self.count = ctypes.c_int32(count)
        self.ltoirs = [numba.cuda.compile(CountingIterator.count_int32_advance, sig=numba.types.void(thisty, numba.types.int32), output='ltoir'),
                       numba.cuda.compile(CountingIterator.count_int32_dereference, sig=numba.types.int32(thisty), output='ltoir')]
        self.prefix = 'count_int32'

    def count_int32_advance(this, diff):
        this[0] += diff

    def count_int32_dereference(this):
        return this[0]

    def host_address(self):
        return ctypes.byref(self.count)

    def size(self):
        return ctypes.sizeof(self.count)  # TODO should be using numba for user-defined types support

    def alignment(self):
        return ctypes.alignment(self.count)  # TODO should be using numba for user-defined types support


def count(offset):
    return CountingIterator(offset)


def map(it, op):
    def source_advance(it_state_ptr, diff):
        pass

    def make_advance_codegen(name):
        retty = types.void
        statety = types.CPointer(types.int8)
        distty = types.int32

        def codegen(context, builder, sig, args):
            state_ptr, dist = args
            fnty = ir.FunctionType(ir.VoidType(), (ir.PointerType(ir.IntType(8)), ir.IntType(32)))
            fn = cgutils.get_or_insert_function(builder.module, fnty, name)
            builder.call(fn, (state_ptr, dist))

        return signature(retty, statety, distty), codegen

    def advance_codegen(func_to_overload, name):
        @intrinsic
        def intrinsic_impl(typingctx, it_state_ptr, diff):
            return make_advance_codegen(name)

        @overload(func_to_overload, target='cuda')
        def impl(it_state_ptr, diff):
            def impl(it_state_ptr, diff):
                return intrinsic_impl(it_state_ptr, diff)
            return impl

    def source_dereference(it_state_ptr):
        pass

    def make_dereference_codegen(name):
        retty = types.int32
        statety = types.CPointer(types.int8)

        def codegen(context, builder, sig, args):
            state_ptr, = args
            fnty = ir.FunctionType(ir.IntType(32), (ir.PointerType(ir.IntType(8)),))
            fn = cgutils.get_or_insert_function(builder.module, fnty, name)
            return builder.call(fn, (state_ptr,))

        return signature(retty, statety), codegen

    def dereference_codegen(func_to_overload, name):
        @intrinsic
        def intrinsic_impl(typingctx, it_state_ptr):
            return make_dereference_codegen(name)

        @overload(func_to_overload, target='cuda')
        def impl(it_state_ptr):
            def impl(it_state_ptr):
                return intrinsic_impl(it_state_ptr)
            return impl

    def make_op_codegen(name):
        retty = types.int32
        valty = types.int32

        def codegen(context, builder, sig, args):
            val, = args
            fnty = ir.FunctionType(ir.IntType(32), (ir.IntType(32),))
            fn = cgutils.get_or_insert_function(builder.module, fnty, name)
            return builder.call(fn, (val,))

        return signature(retty, valty), codegen

    def op_codegen(func_to_overload, name):
        @intrinsic
        def intrinsic_impl(typingctx, val):
            return make_op_codegen(name)

        @overload(func_to_overload, target='cuda')
        def impl(val):
            def impl(val):
                return intrinsic_impl(val)
            return impl

    advance_codegen(source_advance, f"{it.prefix}_advance")
    dereference_codegen(source_dereference, f"{it.prefix}_dereference")
    op_codegen(op, op.__name__)

    class TransformIterator:
        def __init__(self, it, op):
            self.it = it  # TODO support raw pointers
            self.op = op
            self.prefix = f'transform_{it.prefix}_{op.__name__}'
            self.ltoirs = it.ltoirs + [numba.cuda.compile(TransformIterator.transform_advance,
                                                          sig=numba.types.void(numba.types.CPointer(numba.types.char), numba.types.int32),
                                                          output='ltoir',
                                                          abi_info={"abi_name": f"{self.prefix}_advance"}),
                                       numba.cuda.compile(TransformIterator.transform_dereference,
                                                          sig=numba.types.int32(numba.types.CPointer(numba.types.char)),
                                                          output='ltoir',
                                                          abi_info={"abi_name": f"{self.prefix}_dereference"}),
                                       numba.cuda.compile(op, sig=numba.types.int32(numba.types.int32), output='ltoir')]

        def transform_advance(it_state_ptr, diff):
            source_advance(it_state_ptr, diff)  # just a function call

        def transform_dereference(it_state_ptr):
            return op(source_dereference(it_state_ptr))  # just a function call

        def host_address(self):
            return self.it.host_address()  # TODO support stateful operators

        def size(self):
            return self.it.size()  # TODO fix for stateful op

        def alignment(self):
            return self.it.alignment()  # TODO fix for stateful op

    return TransformIterator(it, op)


def parallel_algorithm(d_in, num_items, d_out):
    ltoirs = [ltoir[0] for ltoir in d_in.ltoirs]
    LTOIRArrayType = ctypes.c_char_p * len(ltoirs)
    LTOIRSizesArrayType = ctypes.c_int * len(ltoirs)
    ltoir_pointers = [ctypes.c_char_p(ltoir) for ltoir in ltoirs]
    ltoir_sizes = [len(ltoir) for ltoir in ltoirs]
    input_ltoirs = LTOIRArrayType(*ltoir_pointers)
    input_ltoir_sizes = LTOIRSizesArrayType(*ltoir_sizes)
    bindings = ctypes.CDLL('build/libkernel.so')
    bindings.host_code.argtypes = [ctypes.c_int,     # size
                                   ctypes.c_int,     # alignment
                                   ctypes.c_void_p,  # input_pointer
                                   ctypes.c_char_p,  # prefix
                                   ctypes.c_int,     # num_items
                                   ctypes.c_void_p,  # output_pointer
                                   ctypes.POINTER(ctypes.c_char_p),
                                   ctypes.POINTER(ctypes.c_int),
                                   ctypes.c_int]
    output_pointer = d_out.device_ctypes_pointer.value
    bindings.host_code(d_in.size(), d_in.alignment(), d_in.host_address(), d_in.prefix.encode('utf-8'),
                       num_items, output_pointer, input_ltoirs, input_ltoir_sizes, len(ltoirs))


# User Code
num_items = 4
output_array = numba.cuda.device_array(num_items, dtype=np.int32)

## Repeat
r = repeat(42)
parallel_algorithm(r, num_items, output_array)
print("expect: 42 42 42 42;  get: ", " ".join([str(x) for x in output_array.copy_to_host()]))

## Count
c = count(42)
parallel_algorithm(c, num_items, output_array)
print("expect: 42 43 44 45;  get: ", " ".join([str(x) for x in output_array.copy_to_host()]))

## Transform
def mult(x):
    return x * 2

mult_42_by_2 = map(r, mult)
parallel_algorithm(mult_42_by_2, num_items, output_array)
print("expect: 84 84 84 84;  get: ", " ".join([str(x) for x in output_array.copy_to_host()]))

def add(x):
    return x + 10

mult_42_by_2_plus10 = map(mult_42_by_2, add)
parallel_algorithm(mult_42_by_2_plus10, num_items, output_array)
print("expect: 94 94 94 94;  get: ", " ".join([str(x) for x in output_array.copy_to_host()]))

mult_count_by_2 = map(c, mult)
parallel_algorithm(mult_count_by_2, num_items, output_array)
print("expect: 84 86 88 90;  get: ", " ".join([str(x) for x in output_array.copy_to_host()]))

mult_count_by_2_and_add_10 = map(mult_count_by_2, add)
parallel_algorithm(mult_count_by_2_and_add_10, num_items, output_array)
print("expect: 94 96 98 100; get:", " ".join([str(x) for x in output_array.copy_to_host()]))

input_array = numba.cuda.to_device(np.array([4, 3, 2, 1], dtype=np.int32))
ptr = pointer(input_array)  # TODO this transformation should be hidden on the transform implementation side
parallel_algorithm(ptr, num_items, output_array)
print("expect:  4  3  2 1  ; get:", " ".join([str(x) for x in output_array.copy_to_host()]))

input_array = numba.cuda.to_device(np.array([4, 3, 2, 1], dtype=np.int32))
ptr = pointer(input_array)  # TODO this transformation should be hidden on the transform implementation side
tptr = map(ptr, mult)
parallel_algorithm(tptr, num_items, output_array)
print("expect:  8  6  4 2  ; get:", " ".join([str(x) for x in output_array.copy_to_host()]))

streamed_input = cache(input_array, 'stream')
parallel_algorithm(streamed_input, num_items, output_array)
print("expect:  4  3  2 1  ; get:", " ".join([str(x) for x in output_array.copy_to_host()]))
```
CMakeLists.txt

```cmake
cmake_minimum_required(VERSION 3.21)
project(CCCL_C_Parallel LANGUAGES CUDA CXX)

add_library(kernel SHARED kernel.cpp)
set_property(TARGET kernel PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET kernel PROPERTY CXX_STANDARD 20)

find_package(CUDAToolkit REQUIRED)
# TODO Use static versions of cudart, nvrtc, and nvJitLink
target_link_libraries(kernel PRIVATE CUDA::cudart CUDA::nvrtc CUDA::nvJitLink CUDA::cuda_driver)
```
kernel.cpp

```c++
#include <cuda.h>
#include <cuda_runtime.h>
#include <nvJitLink.h>
#include <nvrtc.h>

#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

void check(nvrtcResult result) {
  if (result != NVRTC_SUCCESS) {
    throw std::runtime_error(std::string("NVRTC error: ") + nvrtcGetErrorString(result));
  }
}

void check(CUresult result) {
  if (result != CUDA_SUCCESS) {
    const char* str = nullptr;
    cuGetErrorString(result, &str);
    throw std::runtime_error(std::string("CUDA error: ") + str);
  }
}

void check(nvJitLinkResult result) {
  if (result != NVJITLINK_SUCCESS) {
    throw std::runtime_error(std::string("nvJitLink error: ") + std::to_string(result));
  }
}

extern "C" void host_code(int iterator_size,
                          int iterator_alignment,
                          void *pointer_to_cpu_bytes_storing_value,
                          const char *prefix,
                          int num_items,
                          void *pointer_to_gpu_memory,
                          const void **input_ltoirs,
                          const int *input_ltoir_sizes,
                          int num_input_ltoirs)
{
  std::string deref = std::string("#define DEREF ") + prefix + "_dereference\n" +
                      "extern \"C\" __device__ int DEREF(const char *state); \n";
  std::string adv = std::string("#define ADV ") + prefix + "_advance\n" +
                    "extern \"C\" __device__ void ADV(char *state, int distance); \n";

  std::string state = std::string("struct __align__(") + std::to_string(iterator_alignment) + R"XXX() iterator_t {
  // using iterator_category = cuda::std::random_access_iterator_tag; // TODO add include to libcu++
  using value_type = int;
  using difference_type = int;
  using pointer = int;
  using reference = int;
  __device__ value_type operator*() const { return DEREF(data); }
  __device__ iterator_t& operator+=(difference_type diff) { ADV(data, diff); return *this; }
  __device__ value_type operator[](difference_type diff) const { return *(*this + diff); }
  __device__ iterator_t operator+(difference_type diff) const {
    iterator_t result = *this;
    result += diff;
    return result;
  }
  char data[)XXX" + std::to_string(iterator_size) + "]; };\n";

  // CUB kernel accepts an iterator and does some of the following operations on it
  std::string kernel_source = deref + adv + state + R"XXX(
  extern "C" __global__ void device_code(int num_items, iterator_t iterator, int *pointer) {
    iterator_t it = iterator + blockIdx.x;
    for (int i = 0; i < num_items; i++) {
      pointer[i] = it[i];
    }
  }
  )XXX";

  nvrtcProgram prog;
  const char *name = "test_kernel";
  nvrtcCreateProgram(&prog, kernel_source.c_str(), name, 0, nullptr, nullptr);

  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, 0);
  const int cc_major = deviceProp.major;
  const int cc_minor = deviceProp.minor;
  const std::string arch = std::string("-arch=sm_") + std::to_string(cc_major) + std::to_string(cc_minor);

  const char* args[] = { arch.c_str(), "-rdc=true", "-dlto" };
  const int num_args = sizeof(args) / sizeof(args[0]);

  std::size_t log_size{};
  nvrtcResult compile_result = nvrtcCompileProgram(prog, num_args, args);
  check(nvrtcGetProgramLogSize(prog, &log_size));
  std::unique_ptr<char[]> log{new char[log_size]};
  check(nvrtcGetProgramLog(prog, log.get()));
  if (log_size > 1) {
    std::cerr << log.get() << std::endl;
  }
  check(compile_result);

  std::size_t ltoir_size{};
  check(nvrtcGetLTOIRSize(prog, &ltoir_size));
  std::unique_ptr<char[]> ltoir{new char[ltoir_size]};
  check(nvrtcGetLTOIR(prog, ltoir.get()));
  check(nvrtcDestroyProgram(&prog));

  nvJitLinkHandle handle;
  const char* lopts[] = { "-lto", arch.c_str() };
  check(nvJitLinkCreate(&handle, 2, lopts));
  check(nvJitLinkAddData(handle, NVJITLINK_INPUT_LTOIR, ltoir.get(), ltoir_size, name));
  for (int ltoir_id = 0; ltoir_id < num_input_ltoirs; ltoir_id++) {
    check(nvJitLinkAddData(handle, NVJITLINK_INPUT_LTOIR, input_ltoirs[ltoir_id], input_ltoir_sizes[ltoir_id], name));
  }
  check(nvJitLinkComplete(handle));

  std::size_t cubin_size{};
  check(nvJitLinkGetLinkedCubinSize(handle, &cubin_size));
  std::unique_ptr<char[]> cubin{new char[cubin_size]};
  check(nvJitLinkGetLinkedCubin(handle, cubin.get()));
  check(nvJitLinkDestroy(&handle));

  CUlibrary library;
  CUkernel kernel;
  cuLibraryLoadData(&library, cubin.get(), nullptr, nullptr, 0, nullptr, nullptr, 0);
  check(cuLibraryGetKernel(&kernel, library, "device_code"));

  void *pointer_to_cpu_bytes_storing_pointer_to_gpu_memory = &pointer_to_gpu_memory;
  void *kernel_args[] = { &num_items, pointer_to_cpu_bytes_storing_value, pointer_to_cpu_bytes_storing_pointer_to_gpu_memory };
  check(cuLaunchKernel((CUfunction)kernel, 1, 1, 1, 1, 1, 1, 0, 0, kernel_args, nullptr));
  check(cuStreamSynchronize(0));
  check(cuLibraryUnload(library));
}
```

Before this proof-of-concept is merged, we have to address the following:

Future work that can be addressed after initial support of primitive types:

rwgk commented 3 weeks ago

Discussion items for sync when Georgii returns from PTO:

https://github.com/rwgk/cccl/tree/georgii_poc_2479

https://github.com/rwgk/numba/tree/misc_doc_fixes — By-product of systematically working through Numba documentation top-to-bottom.

main.py on the georgii_poc_2479 branch

@register_jitable fixes the kind of error I asked about on October 17 (PR #2595):
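A sketch of the fix (assuming numba.extending.register_jitable), applied to the failing example from that PR:

```python
from numba.extending import register_jitable

# register_jitable makes the helper visible to Numba's typing, so calling
# it from jitted code no longer raises "Untyped global name".
@register_jitable
def other_unary_op(distance):
    permutation = (4, 2, 0, 3, 1)
    return permutation[distance % len(permutation)]

def input_unary_op(distance):
    return 2 * other_unary_op(distance)
```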

Unanswered question on Slack (cccl-tm): @numba.cuda.reduce vs cub::DeviceReduce

Why is the cuda.cooperative implementation so different (is it?) from cuda.parallel?

Is there anything we can use or learn from cupyx.jit.cub, cupyx.jit.thrust?

rwgk commented 2 weeks ago

I managed to combine Georgii's POC ConstantIterator and CountingIterator with the code I had under #2595:

It's pretty simple, and I got most of the way there pretty quickly, too, but then it took me several hours to get my head around the handling of cccl_iterator_t::state.
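As I understand it from the PoC above, the subtlety is that the state is an opaque byte blob: the host side only exposes host_address(), size(), and alignment(), and the C side copies those bytes into the kernel-side iterator struct, so what the bytes mean depends entirely on the advance/dereference functions. A toy illustration (not the real cccl_iterator_t):

```python
import ctypes

# Toy illustration: two very different iterator states are both just bytes.
count_state = ctypes.c_int32(42)         # counting iterator: the current count
ptr_state = ctypes.c_void_p(0x7F00DEAD)  # pointer iterator: a (fake) device address

for state in (count_state, ptr_state):
    blob = ctypes.string_at(ctypes.byref(state), ctypes.sizeof(state))
    print(ctypes.sizeof(state), blob.hex())  # opaque bytes handed to the C side
```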

(Next I'll try to plug in the POC map() function.)

rwgk commented 1 week ago

Georgii's POC map() function is now also integrated into the fancy_iterators branch (under the name cu_map(), to not shadow a Python built-in function). All existing tests and the new cu_map() test pass at this commit:

eriknw commented 5 days ago

I have a quick comment regarding map.

I share concerns about shadowing Python builtins, but I think map is singularly special in a functional library. It is likely to be heavily used in code, examples, and documentation. Hence, it is possible to reinforce the conventions for how to use map.

I would prefer to keep the name map if we can establish two conventions such as:

```python
from cuda.parallel.iterators import map as cumap
```

and

```python
import cuda.parallel.iterators as cuit

cuit.map(...)
```

Also, I prefer cumap to cu_map. We have the cuml, cudf, cugraph, cupy, cucim, etc. libraries; I don't see other functions with a cu_ prefix.