ratt-ru / pfb-imaging

Preconditioned forward/backward clean algorithm
MIT License

`prange` doesn't work in overloaded functions #116

Open landmanbester opened 2 months ago

landmanbester commented 2 months ago

Trying to use `prange` inside an overload, e.g.

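from numba import njit, prange
from numba.extending import overload

# JIT_OPTIONS is a dict of jit keyword options defined elsewhere (not shown here)

@njit(**JIT_OPTIONS, parallel=True)
def update(x, xp, r, rp, p, Ap, alpha):
    return update_impl(x, xp, r, rp, p, Ap, alpha)

def update_impl(x, xp, r, rp, p, Ap, alpha):
    # pure-Python stub; only the overload below is compiled
    raise NotImplementedError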

@overload(update_impl, jit_options=JIT_OPTIONS, parallel=True)
def nb_update_impl(x, xp, r, rp, p, Ap, alpha):
    if x.ndim==3:
        def impl(x, xp, r, rp, p, Ap, alpha):
            nband, nx, ny = x.shape
            for b in range(nband):
                for i in prange(nx):
                    for j in range(ny):
                        x[b, i, j] = xp[b, i, j] + alpha * p[b, i, j]
                        r[b, i, j] = rp[b, i, j] + alpha * Ap[b, i, j]
            return x, r
    elif x.ndim==2:
        def impl(x, xp, r, rp, p, Ap, alpha):
            nx, ny = x.shape
            for i in prange(nx):
                for j in range(ny):
                    x[i, j] = xp[i, j] + alpha * p[i, j]
                    r[i, j] = rp[i, j] + alpha * Ap[i, j]
            return x, r
    else:
        raise ValueError("update only implemented for 2D or 3D arrays")

    return impl

This results in the following warning message during compilation:

/home/bester/.venv/pfb/lib/python3.10/site-packages/numba/core/typed_passes.py:336: NumbaPerformanceWarning:
The keyword argument 'parallel=True' was specified but no transformation for parallel execution was possible.

To find out why, try turning on parallel diagnostics, see https://numba.readthedocs.io/en/stable/user/parallel.html#diagnostics for help.

File "../../software/pfb-imaging/pfb/opt/pcg.py", line 21:
@njit(**JIT_OPTIONS, parallel=True)
def update(x, xp, r, rp, p, Ap, alpha):
^

I've seen this before when making nested function calls to `prange` (e.g. here). In that case I get the same warning but I do actually see multiple threads spinning up, whereas the overloaded implementation doesn't seem to parallelize at all. I wonder if this is a bug in numba or if I'm trying something that is not supported.

landmanbester commented 2 months ago

Looking at the parallel diagnostics gives:

In [8]: update.parallel_diagnostics()

================================================================================
 Parallel Accelerator Optimizing:  Function update, /home/bester/software/pfb-imaging/pfb/opt/pcg.py (20)
================================================================================
No source available
------------------------------ After Optimisation ------------------------------
Parallel structure is already optimal.
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------