Python-for-HPC / numbaWithOpenmp


Use ExternalFunction for OpenMP runtime calls #5

Closed. ggeorgakoudis closed this 1 year ago.

ggeorgakoudis commented 1 year ago

Fix #4


ggeorgakoudis commented 1 year ago

Right, it won't work without jitting. Is this functionality a requirement? Can we provide an overload for when jitting is disabled? The ExternalFunction solution works for both host and device code as-is.
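For reference, a minimal sketch of the ExternalFunction approach (assuming the OpenMP runtime is already loaded into the process; the symbol is a real OpenMP entry point, the surrounding scaffolding is illustrative):

from numba import njit, types

# Declare the OpenMP runtime symbol to the compiler; there is no Python-level
# implementation behind this object.
omp_get_thread_num = types.ExternalFunction("omp_get_thread_num", types.int32())

@njit
def inside_jit():
    # The JIT resolves this declaration against the in-process runtime.
    return omp_get_thread_num()

print(inside_jit())
# Calling omp_get_thread_num() directly from the interpreter fails:
# ExternalFunction only has meaning to the compiler.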

DrTodd13 commented 1 year ago

> Right, it won't work without jitting. Is this functionality a requirement? Can we provide an overload for when jitting is disabled? The ExternalFunction solution works for both host and device code as-is.

I think it is a requirement. @stuartarchibald Any suggestions Stuart?

stuartarchibald commented 1 year ago

I'm not sure how this needs to appear to the LLVM pass in terms of the declaration of the OpenMP functions, but from the above I'm guessing that if ExternalFunction is working already then the following sort of thing would make a similar API work both with and without the JIT compiler:

from numba import njit, types
from numba.extending import overload
import ctypes, ctypes.util

# ------------------------------------------------------------------------------
# This is the public API that you want your package to expose, it's a pure
# Python API.

def omp_get_num_threads():
    """Get the number of OpenMP threads"""
    # This uses ctypes to bind to the OpenMP library; could also be cffi.

    # This looks for the name of the GNU OpenMP library
    libompname = ctypes.util.find_library("gomp")

    if libompname is None:
        raise RuntimeError("Cannot find suitable OpenMP library.")

    # dlopen the library; use RTLD_GLOBAL so the symbol is visible to the
    # process-wide dynamic linker, which is how the JIT resolves the
    # ExternalFunction declaration in the overload below.
    libomp = ctypes.CDLL(libompname, mode=ctypes.RTLD_GLOBAL)

    # create the ctypes binding to the function
    ct_omp_get_num_threads = libomp.omp_get_num_threads
    ct_omp_get_num_threads.restype = ctypes.c_int
    ct_omp_get_num_threads.argtypes = ()

    # call the function and return the result
    return ct_omp_get_num_threads()

# ------------------------------------------------------------------------------
# This is a Numba compatible overload for the public API function
# `omp_get_num_threads` defined above.

@overload(omp_get_num_threads)
def ol_omp_get_num_threads():
    # Declare `omp_get_num_threads` as an external function.
    # On the CPU target the OpenMP runtime library will need to be in process
    # for this to succeed.
    # On the CUDA target this is what presents the opportunity to make the
    # runtime call, the declaration is present and the bitcode can be linked in.
    fnty = types.ExternalFunction("omp_get_num_threads", types.int32())
    def impl():
        return fnty()
    return impl

# ------------------------------------------------------------------------------
# Demo code

# This is a python call to a python function
print("Python call to python API omp_get_num_threads():", omp_get_num_threads())

@njit
def compiled_call():
    return omp_get_num_threads()

# This is a compiled call that resolves through the `@overload`.
print("Numba compiled call to omp_get_num_threads():", compiled_call())

# Print the LLVM IR, note that the declaration is made for `omp_get_num_threads`
# but there is no definition.
print("LLVM IR:")
print(compiled_call.inspect_llvm(compiled_call.signatures[0]))

I've used the ctypes module to make the binding in the above, but you could continue using cffi if that's convenient. The important part is exposing a public API that you control as a set of functions/function stubs that can then be overload'd.

@DrTodd13 for a reference point, the above technique is similar to Numba's get_num_threads() (and that family of functions).
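As a rough illustration of how that pattern surfaces to users (this sketch uses only the public numba API):

from numba import njit, get_num_threads

@njit
def compiled():
    # Resolved through an @overload, just like the sketch above.
    return get_num_threads()

# The same name works both interpreted and compiled.
print("interpreted:", get_num_threads())
print("compiled:", compiled())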

Hope this helps!

DrTodd13 commented 1 year ago

@stuartarchibald @ggeorgakoudis I used Stuart's approach, but dynamically generated everything to reduce the boilerplate. Python and Numba invocations should now all work the way we want.
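Roughly, the idea is something like the following (a hedged sketch of the dynamic generation, not the actual implementation; the table of functions and signatures is an illustrative subset):

import ctypes, ctypes.util
from numba import types
from numba.extending import overload

# Illustrative subset: symbol name -> Numba signature. All of these are
# zero-argument, int-returning OpenMP runtime entry points.
_OMP_FUNCS = {
    "omp_get_num_threads": types.int32(),
    "omp_get_thread_num": types.int32(),
    "omp_get_max_threads": types.int32(),
}

_libname = ctypes.util.find_library("gomp")
if _libname is None:
    raise RuntimeError("Cannot find suitable OpenMP library.")
_libomp = ctypes.CDLL(_libname, mode=ctypes.RTLD_GLOBAL)

def _make_stub(name, sig):
    # ctypes binding used by the pure Python path.
    ct_func = getattr(_libomp, name)
    ct_func.restype = ctypes.c_int
    ct_func.argtypes = ()

    def stub():
        return ct_func()
    stub.__name__ = name

    # ExternalFunction declaration used by the compiled path.
    fnty = types.ExternalFunction(name, sig)

    @overload(stub)
    def ol_stub():
        def impl():
            return fnty()
        return impl

    return stub

# Generate the public API: one stub per entry in the table.
for _name, _sig in _OMP_FUNCS.items():
    globals()[_name] = _make_stub(_name, _sig)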

DrTodd13 commented 1 year ago

@stuartarchibald This approach seems to work for Python, for Numba (outside of a target region), and inside a CPU device target region. However, for CUDA target regions, I get:

Traceback (most recent call last):
  File "/mnt/home/taanders/OpenMP_in_Numba/code/ggeorgakoudis/numba/numba/core/errors.py", line 817, in new_error_context
    yield
  File "/mnt/home/taanders/OpenMP_in_Numba/code/ggeorgakoudis/numba/numba/core/lowering.py", line 252, in lower_block
    self.lower_inst(inst)
  File "/mnt/home/taanders/OpenMP_in_Numba/code/ggeorgakoudis/numba/numba/core/lowering.py", line 401, in lower_inst
    val = self.lower_assign(ty, inst)
  File "/mnt/home/taanders/OpenMP_in_Numba/code/ggeorgakoudis/numba/numba/core/lowering.py", line 583, in lower_assign
    return self.lower_expr(ty, value)
  File "/mnt/home/taanders/OpenMP_in_Numba/code/ggeorgakoudis/numba/numba/core/lowering.py", line 1119, in lower_expr
    res = self.lower_call(resty, expr)
  File "/mnt/home/taanders/OpenMP_in_Numba/code/ggeorgakoudis/numba/numba/core/lowering.py", line 850, in lower_call
    res = self._lower_call_normal(fnty, expr, signature)
  File "/mnt/home/taanders/OpenMP_in_Numba/code/ggeorgakoudis/numba/numba/core/lowering.py", line 1082, in _lower_call_normal
    impl = self.context.get_function(fnty, signature)
  File "/mnt/home/taanders/OpenMP_in_Numba/code/ggeorgakoudis/numba/numba/core/base.py", line 571, in get_function
    return self.get_function(type(fn), sig)
  File "/mnt/home/taanders/OpenMP_in_Numba/code/ggeorgakoudis/numba/numba/core/base.py", line 580, in get_function
    return self.get_function(fn, sig, _firstcall=False)
  File "/mnt/home/taanders/OpenMP_in_Numba/code/ggeorgakoudis/numba/numba/core/base.py", line 582, in get_function
    raise NotImplementedError("No definition for lowering %s%s" % (key, sig))
NotImplementedError: No definition for lowering <class 'numba.core.types.functions.Function'>() -> int32

Adding overload(omp_get_num_threads, target="cuda") doesn't seem to help. Suggestions?