ratt-ru / pfb-imaging

Preconditioned forward/backward clean algorithm

FFH #43

Closed: landmanbester closed this issue 3 years ago

landmanbester commented 3 years ago

Here is my attempt at a fast version of MFS Hogbom:

import numpy as np
from numba import njit

@njit(nogil=True, fastmath=True, inline='always')
def hogbom(ID, PSF, x, gamma=0.1, pf=0.1, maxit=5000):
    nx, ny = ID.shape
    IR = ID.copy()
    # peak-find in the square of the residual rather than its absolute value
    IRsearch = IR**2
    pq = IRsearch.argmax()
    p = pq//ny
    q = pq - p*ny
    IRmax = np.sqrt(IRsearch[p, q])
    tol = pf*IRmax
    k = 0
    while IRmax > tol and k < maxit:
        xhat = IR[p, q]
        x[p, q] += gamma * xhat
        # subtract the shifted, scaled PSF (PSF is padded to 2nx x 2ny)
        IR -= gamma * xhat * PSF[nx-p:2*nx - p, ny-q:2*ny - q]
        IRsearch = IR**2
        pq = IRsearch.argmax()
        p = pq//ny
        q = pq - p*ny
        IRmax = np.sqrt(IRsearch[p, q])
        k += 1
    return x, IR
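
For anyone who wants to prod at it, here is a minimal mock setup (the dims and dtypes are illustrative assumptions: a float64 dirty image and a PSF padded to twice the image size, since the subtraction always slices an nx x ny window out of it):

import numpy as np

# hypothetical mock inputs for hogbom above
nx, ny = 1024, 1024
rng = np.random.default_rng(42)
ID = rng.normal(size=(nx, ny))     # dirty image, float64
PSF = np.zeros((2*nx, 2*ny))
PSF[nx, ny] = 1.0                  # unit delta-function PSF, just for testing
x = np.zeros((nx, ny))             # model image, updated in place

model, residual = hogbom(ID, PSF, x, gamma=0.1, pf=0.1, maxit=100)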

Interestingly, jitting it makes almost no difference, even if I run it for many iterations. @JSKenyon this is one of the self-contained little problems I promised you, in case you are interested / have the time to look at it. It is faster than the standard version, which does peak finding in the absolute value of the image, whereas here I search in the square. I have also tried implementing it in jax, but no cigar as of yet. Here is the attempt:

import numpy as np
import jax.numpy as jnp
from jax import lax
from jax.ops import index_add

def hogbom_jax(ID, PSF, x, gamma=0.1, pf=0.1, maxit=5000):
    nx, ny = ID.shape
    IR = ID.copy()
    IRsearch = IR**2
    IRmax = IRsearch.max()
    tol = pf*np.sqrt(IRmax)
    def fun(inputs):
        tol, IRmax, IR, IRsearch, PSF, x, shape, gamma = inputs
        nx, ny = shape
        pq = IRsearch.argmax()
        p = pq//ny
        q = pq - p*ny
        xhat = IR[p, q]
        x = index_add(x, (p, q), gamma * xhat)
        Ip = slice(nx-p, 2*nx-p, 1)
        Iq = slice(ny-q, 2*ny-q, 1)
        # this is the line jax objects to: the slice bounds are traced values
        IR = IR - gamma * xhat * PSF[Ip, Iq]
        IRsearch = IR**2
        IRmax = IRsearch[p, q]
        return [tol, IRmax, IR, IRsearch, PSF, x, (nx, ny), gamma]
    out = lax.while_loop(lambda a: jnp.sqrt(a[1]) >= a[0], fun,
            init_val=[tol, IRmax, IR, IRsearch, PSF, x, (nx, ny), gamma])
    return out[5], out[2]

Still a bit new to jax, so I'm probably not using it right, but it doesn't seem to like the IR = IR - gamma * xhat * PSF[Ip, Iq] bit. It complains with:

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. 
To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice 
(JAX does not support dynamically sized arrays within JIT compiled functions).

So it looks like I can get it working with dynamic slicing but then I won't be able to jit it.
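
(As the working version further down shows, lax.dynamic_slice itself is perfectly happy under jit; the restriction only bites the NumPy-style PSF[Ip, Iq] syntax with traced bounds. A minimal standalone sketch, separate from the deconvolution itself:)

from functools import partial
from jax import jit, lax

@partial(jit, static_argnums=(3, 4))
def take_window(PSF, p, q, nx, ny):
    # slice *sizes* must be static under jit, but the start
    # indices (nx - p, ny - q) may be traced values
    return lax.dynamic_slice(PSF, (nx - p, ny - q), (nx, ny))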

JSKenyon commented 3 years ago

Can you tell me the dims and dtypes of the inputs? I will need to mock something up.

JSKenyon commented 3 years ago

Also, you may want to enable caching; otherwise compilation time could throw off your timings.

JSKenyon commented 3 years ago

Something to think about is that this is already pretty much pure numpy. Numpy is fast, so we shouldn't expect miracles here. Still prodding around though.

landmanbester commented 3 years ago

Yeah, just switching from numpy to numba probably won't get you much, but I was thinking of parallelising the peak finding and PSF subtraction, as those are the most expensive steps. PSF subtraction is trivial with prange, but you would need to divide and conquer for the peak finding.
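
A possible shape for that divide-and-conquer peak find (a sketch only, not the code that ended up in the repo): prange over rows, each thread recording its own maximum, followed by a cheap serial reduction.

import numpy as np
from numba import njit, prange

@njit(parallel=True, nogil=True, cache=True)
def parallel_argmax_sq(IR):
    # each prange iteration scans one row and records that row's peak;
    # the nx-sized per-row arrays are then reduced serially
    nx, ny = IR.shape
    row_max = np.zeros(nx)
    row_arg = np.zeros(nx, dtype=np.int64)
    for i in prange(nx):
        best = -1.0
        arg = 0
        for j in range(ny):
            v = IR[i, j] * IR[i, j]
            if v > best:
                best = v
                arg = j
        row_max[i] = best
        row_arg[i] = arg
    p = np.argmax(row_max)
    return p, row_arg[p], row_max[p]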

bennahugo commented 3 years ago

You are still using mostly numpy functions that do not run in parallel. Have you tried rolling your own parallel peakfinding and subtraction functions?

landmanbester commented 3 years ago

That's what I'm trying now. I was wondering why the numba version is sometimes slightly slower than the numpy one, and thought I might have missed something.

landmanbester commented 3 years ago

For the record, with @JSKenyon's help we got a jax version working, but it doesn't seem to offer any benefit (at least not on a CPU). Here it is:

import jax.numpy as jnp
from jax import jit, lax
from jax.ops import index_add

@jit
def hogbom_jax(ID, PSF, x, gamma=0.1, pf=0.1, maxit=5000):
    nx, ny = ID.shape
    IR = jnp.array(ID, copy=True)
    IRsearch = jnp.square(IR)
    pq = jnp.argmax(IRsearch)
    p = pq//ny
    q = pq - p*ny
    IRmax = jnp.sqrt(IRsearch[p, q])
    tol = pf*IRmax
    k = 0

    def cond_func(inputs):

        IRmax, IR, IRsearch, PSF, x, loc, tol, gamma, k = inputs

        return (k < maxit) & (IRmax > tol)

    def body_func(inputs):
        IRmax, IR, IRsearch, PSF, x, loc, tol, gamma, k = inputs
        nx, ny = IR.shape
        p, q = loc
        xhat = IR[p, q]
        x = index_add(x, (p, q), gamma * xhat)
        # static slice sizes, traced start indices: this is what the
        # NumPy-style slicing in the first attempt could not express
        modconv = lax.dynamic_slice(PSF, [nx-p, ny-q], [nx, ny])
        IR = IR - gamma * xhat * modconv
        IRsearch = jnp.square(IR)
        pq = IRsearch.argmax()
        p = pq//ny
        q = pq - p*ny
        IRmax = jnp.sqrt(IRsearch[p, q])
        return (IRmax, IR, IRsearch, PSF, x, (p, q), tol, gamma, k+1)

    init_val = (IRmax, IR, IRsearch, PSF, x, (p, q), tol, gamma, k)
    out = lax.while_loop(cond_func, body_func, init_val)

    return out[4], out[1]

sjperkins commented 3 years ago

That's what I'm trying now, was wondering why the numba version is sometimes slightly slower than the numpy one, thought I might have missed something

Yeah, single-threaded numba isn't going to do any better with NumPy semantics.

What you could do is add parallel=True to the decorator, as that invokes some parallel compilation optimisations:

https://numba.pydata.org/numba-doc/latest/user/parallel.html
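
e.g. for the PSF subtraction step, something along these lines (a sketch, assuming the same 2x-padded PSF layout as above):

from numba import njit, prange

@njit(parallel=True, nogil=True, cache=True)
def subtract_psf(IR, PSF, p, q, gamma, xhat):
    # subtract the shifted, scaled PSF in parallel over image rows
    nx, ny = IR.shape
    for i in prange(nx):
        for j in range(ny):
            IR[i, j] -= gamma * xhat * PSF[nx - p + i, ny - q + j]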

landmanbester commented 3 years ago

Somewhat surprisingly, the peak finding doesn't dominate the compute at all. PSF subtraction is way more expensive, especially for cubes. I've taken the DDF approach and simply parallelised it using numexpr, which seems to work better than naively sticking a prange in one of the loops with numba.
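
Roughly like this (my reconstruction of the numexpr call, not the exact code in the repo):

import numexpr as ne

def subtract_psf_ne(IR, PSF, p, q, gamma, xhat):
    nx, ny = IR.shape
    psf_slice = PSF[nx - p:2*nx - p, ny - q:2*ny - q]
    # numexpr compiles the expression and evaluates it in parallel
    # across its internal thread pool, writing the result back into IR
    ne.evaluate('IR - gamma * xhat * psf_slice', out=IR)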

On the jax front, @JSKenyon saw almost an order of magnitude speedup when he ran Hogbom on his GPU. I think I'll have a crack at writing the full CLEAN minor cycle in jax. With jaxlets I can probably also write the SARA minor cycle in jax...
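
For anyone repeating the GPU comparison: jax dispatches asynchronously, so the first call should be discarded (it triggers compilation) and results must be blocked on before stopping the clock. A sketch, reusing the mock inputs assumed earlier:

import time
import jax.numpy as jnp

# hypothetical device arrays built from the numpy mocks above
ID_j, PSF_j, x_j = jnp.asarray(ID), jnp.asarray(PSF), jnp.asarray(x)

model, residual = hogbom_jax(ID_j, PSF_j, x_j)   # first call compiles
model.block_until_ready()

t0 = time.perf_counter()
model, residual = hogbom_jax(ID_j, PSF_j, x_j)
residual.block_until_ready()
print(f'minor cycle: {time.perf_counter() - t0:.3f}s')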

landmanbester commented 3 years ago

Added in https://github.com/ratt-ru/pfb-clean/pull/49