Closed landmanbester closed 3 years ago
Can you tell me dims and dtypes on the inputs? Will need to mock something up.
Also, you may want to enable caching. That could throw off your timings.
Something to think about is that this is already pretty much pure numpy. Numpy is fast, so we shouldn't expect miracles here. Still prodding around though.
Yeah, just switching from numpy to numba probably won't get you much, but I was thinking of parallelising the peak finding and PSF subtraction, as those are the most expensive steps. PSF subtraction is trivial with prange, but you would need to divide and conquer for the peak finding.
You are still using mostly numpy functions that do not run in parallel. Have you tried rolling your own parallel peakfinding and subtraction functions?
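The divide-and-conquer peak finding mentioned above can be sketched in pure numpy: split the image into row blocks, take each block's argmax independently, and reduce. The per-block loop is the part a threaded backend (e.g. numba's prange) could parallelise; the function name and block count below are illustrative, not from the repo.

```python
import numpy as np

def chunked_peak(IRsearch, nblocks=4):
    """Global argmax of a 2D array via per-block reductions.

    Each block's argmax is independent of the others, so the loop
    below is the part that could be split across threads.
    """
    nx, ny = IRsearch.shape
    edges = np.linspace(0, nx, nblocks + 1).astype(np.int64)
    best_val = -np.inf
    best_pq = 0
    for b in range(nblocks):               # independent per-block argmax
        lo, hi = edges[b], edges[b + 1]
        block = IRsearch[lo:hi]
        pq = block.argmax()                # flat index within the block
        val = block.flat[pq]
        if val > best_val:
            best_val = val
            best_pq = pq + lo * ny         # offset back to a global flat index
    return best_pq // ny, best_pq - (best_pq // ny) * ny
```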
On Wed, 17 Mar 2021, 13:39 Landman Bester, @.***> wrote:
Here is my attempt at a fast version of MFS Hogbom:
```python
@njit(nogil=True, fastmath=True, inline='always')
def hogbom(ID, PSF, x, gamma=0.1, pf=0.1, maxit=5000):
    nx, ny = ID.shape
    IR = ID.copy()
    IRsearch = IR**2
    pq = IRsearch.argmax()
    p = pq//ny
    q = pq - p*ny
    IRmax = np.sqrt(IRsearch[p, q])
    tol = pf*IRmax
    k = 0
    while IRmax > tol and k < maxit:
        xhat = IR[p, q]
        x[p, q] += gamma * xhat
        IR -= gamma * xhat * PSF[nx-p:2*nx - p, ny-q:2*ny - q]
        IRsearch = IR**2
        pq = IRsearch.argmax()
        p = pq//ny
        q = pq - p*ny
        IRmax = np.sqrt(IRsearch[p, q])
        k += 1
    return x, IR
```
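The squared-image search used above is easy to sanity-check: the argmax of `IR**2` coincides with the argmax of `|IR|`, so the whole-image `abs` pass is avoided and only one scalar sqrt is needed to recover the peak magnitude. A toy illustration (the array is made up):

```python
import numpy as np

IR = np.array([[0.5, -3.0],
               [2.0,  1.5]])

# Searching the square instead of abs(IR): same peak location,
# and only a single scalar sqrt recovers the peak magnitude.
IRsearch = IR**2
pq = IRsearch.argmax()
p, q = divmod(pq, IR.shape[1])
IRmax = np.sqrt(IRsearch[p, q])
```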
Interestingly, jitting it makes almost no difference, even if I run it for many iterations. @JSKenyon this is one of the self-contained little problems I promised you, in case you are interested / have the time to look at it. It is faster than the standard version, which does peak finding in the absolute value of the image, whereas I search in the square. I have also tried implementing it in jax but no cigar as of yet. Here is the attempt:
```python
def hogbom_jax(ID, PSF, x, gamma=0.1, pf=0.1, maxit=5000):
    nx, ny = ID.shape
    IR = ID.copy()
    IRsearch = IR**2
    IRmax = IRsearch.max()
    tol = pf*np.sqrt(IRmax)

    def fun(input):
        tol, IRmax, IR, IRsearch, PSF, x, shape, tol, gamma = input[0],\
            input[1], input[2], input[3], input[4], input[5], input[6],\
            input[7], input[8]
        nx, ny = shape[0], shape[1]
        pq = IRsearch.argmax()
        p = pq//ny
        q = pq - p*ny
        xhat = IR[p, q]
        x = index_add(x, (p, q), gamma * xhat)
        Ip = slice(nx-p, 2*nx-p, 1)
        Iq = slice(ny-q, 2*ny-q, 1)
        IR = IR - gamma * xhat * PSF[Ip, Iq]
        IRsearch = IR**2
        IRmax = IRsearch[p, q]
        return [tol, IRmax, IR, IRsearch, PSF, x, (nx, ny), tol, gamma]

    out = lax.while_loop(lambda a: jnp.sqrt(a[1]) >= a[0],
                         lambda a: fun(a),
                         init_val=[tol, IRmax, IR, IRsearch, PSF, x,
                                   (nx, ny), tol, gamma])
    return out[5], out[2]
```
Still a bit new to jax, so I'm probably not using it right, but it doesn't seem to like the `IR = IR - gamma * xhat * PSF[Ip, Iq]` bit; it complains with:
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
So it looks like I can get it working with dynamic slicing but then I won't be able to jit it.
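For what it's worth, `lax.dynamic_slice` itself is jit-compatible: only the slice *sizes* must be static, while the start indices may be traced values. A toy sketch of slicing a window out of the PSF under jit (function name and shapes are illustrative):

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

@partial(jax.jit, static_argnums=(3, 4))
def window(PSF, p, q, nx, ny):
    # Start indices (p, q) may be traced values; the window shape
    # (nx, ny) must be known at trace time, hence static_argnums.
    return lax.dynamic_slice(PSF, (p, q), (nx, ny))
```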
That's what I'm trying now. I was wondering why the numba version is sometimes slightly slower than the numpy one; thought I might have missed something.
For the record, with @JSKenyon's help we got a jax version working, but it doesn't seem to offer any benefit (at least not on a CPU). Here it is:
```python
@jit
def hogbom_jax(ID, PSF, x, gamma=0.1, pf=0.1, maxit=5000):
    nx, ny = ID.shape
    IR = jnp.array(ID, copy=True)
    IRsearch = jnp.square(IR)
    pq = jnp.argmax(IRsearch)
    p = pq//ny
    q = pq - p*ny
    IRmax = jnp.sqrt(IRsearch[p, q])
    tol = pf*IRmax
    k = 0

    def cond_func(inputs):
        IRmax, IR, IRsearch, PSF, x, loc, tol, gamma, k = inputs
        return (k < maxit) & (IRmax > tol)

    def body_func(inputs):
        IRmax, IR, IRsearch, PSF, x, loc, tol, gamma, k = inputs
        nx, ny = IR.shape
        p, q = loc
        xhat = IR[p, q]
        x = index_add(x, (p, q), gamma * xhat)
        modconv = lax.dynamic_slice(PSF, [nx-p, ny-q], [nx, ny])
        IR = IR - gamma * xhat * modconv
        IRsearch = jnp.square(IR)
        pq = IRsearch.argmax()
        p = pq//ny
        q = pq - p*ny
        IRmax = jnp.sqrt(IRsearch[p, q])
        return (IRmax, IR, IRsearch, PSF, x, (p, q), tol, gamma, k+1)

    init_val = (IRmax, IR, IRsearch, PSF, x, (p, q), tol, gamma, k)
    out = lax.while_loop(cond_func, body_func, init_val)
    return out[4], out[1]
```
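One portability note on the block above: `jax.ops.index_add` has since been removed from JAX; the equivalent functional update is now spelled with the `.at` property. A minimal example:

```python
import jax.numpy as jnp

x = jnp.zeros((3, 3))
# Functional update: returns a new array with x[1, 2] incremented,
# replacing the old index_add(x, (1, 2), 0.5) spelling.
x = x.at[1, 2].add(0.5)
```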
> That's what I'm trying now, was wondering why the numba version is sometimes slightly slower than the numpy one, thought I might have missed something
Yeah, numba isn't going to do any better with NumPy semantics, single-threaded.
What you could do is add `parallel=True` to the decorator, as that invokes some parallel compilation optimisations:
https://numba.pydata.org/numba-doc/latest/user/parallel.html
Somewhat surprisingly, the peak finding doesn't dominate the compute at all. PSF subtraction is way more expensive, especially for cubes. I've taken the DDF approach and simply parallelised this using numexpr, which seems to work better than naively sticking a prange in one of the loops with numba.
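For reference, the numexpr version of the subtraction is essentially a one-liner: numexpr compiles the elementwise expression and evaluates it in chunks over its own thread pool. A minimal sketch, with illustrative names and the shifted PSF window passed in pre-sliced:

```python
import numpy as np
import numexpr as ne

def subtract_psf_ne(IR, gamma, xhat, PSF_shifted):
    # numexpr evaluates the expression in parallel chunks across its
    # thread pool; out=IR writes back in place without a temporary.
    ne.evaluate('IR - gamma * xhat * PSF_shifted', out=IR)
    return IR
```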
On the jax front, @JSKenyon saw almost an order of magnitude speedup when he ran Hogbom on his GPU. I think I'll have a crack at writing the full CLEAN minor cycle in jax. With jaxlets I can probably also write the SARA minor cycle in jax...