ratt-ru / pfb-imaging

Preconditioned forward/backward clean algorithm
MIT License

Parallel kron_matvec #33

Closed: landmanbester closed this 3 years ago

landmanbester commented 3 years ago

I need to impose a small spatial correlation in the preconditioner for extended emission for the point-source separation to work. This means that I have to invert a covariance matrix which can be factorised as a Kronecker product over dimensions:

$$A = A_1 \otimes A_2 \otimes \cdots$$

Unfortunately, even though the covariance matrix itself is Toeplitz, its inverse is not, so I can't use an FFT-based matrix-vector product. I can, however, compute the inverse in each dimension explicitly (the inverse of a Kronecker product is the Kronecker product of the inverses) and use the efficient Kronecker matrix-vector product algorithm, which can be coded up as

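```python
import numpy as np

def kron_matvec(A, b):
    D = len(A)
    N = b.size
    x = b
    for d in range(D):
        Gd = A[d].shape[0]
        # view the current vector as a Gd x (N // Gd) matrix
        X = np.reshape(x, (Gd, N//Gd))
        Z = np.zeros((Gd, N//Gd), dtype=A[0].dtype)
        # explicit triple loop computing Z = A[d] @ X
        for i in range(Gd):
            for j in range(N//Gd):
                for k in range(Gd):
                    Z[i, j] += A[d][i, k] * X[k, j]
        # transpose so that the next factor acts on the leading axis
        x = Z.T.flatten()
    return x
```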

where A is a tuple containing the individual matrices going into the Kronecker product and b is the vector it acts on. I have jitted the function, but it is still debilitatingly slow. @JSKenyon @sjperkins I would be curious if you have any ideas for speeding it up. Note that we will eventually need this for the posterior smoothing of the gains as well.
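The opening comment leans on two facts: the inverse of a Kronecker product is the Kronecker product of the inverses, and the inverse of a Toeplitz matrix is generally not Toeplitz. Both can be checked on small matrices. The sketch below assumes an exponential kernel exp(-|i - j| / l) as a stand-in for the actual covariance (the thread does not say which kernel is used), and `toeplitz_cov` is a hypothetical helper.

```python
import numpy as np

def toeplitz_cov(n, length_scale):
    """Symmetric Toeplitz matrix with entries exp(-|i - j| / length_scale).

    Hypothetical kernel, chosen only for illustration.
    """
    idx = np.arange(n)
    return np.exp(-np.abs(idx[:, None] - idx[None, :]) / length_scale)

A1 = toeplitz_cov(3, 2.0)
A2 = toeplitz_cov(4, 1.5)

# Inverse of a Kronecker product == Kronecker product of the inverses,
# so each (small) factor can be inverted on its own.
lhs = np.linalg.inv(np.kron(A1, A2))
rhs = np.kron(np.linalg.inv(A1), np.linalg.inv(A2))
print(np.allclose(lhs, rhs))

# A2 is constant along its diagonals, but its inverse is not:
# the corner entry of the main diagonal differs from the interior ones.
A2_inv = np.linalg.inv(A2)
print(np.isclose(A2_inv[0, 0], A2_inv[1, 1]))
```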

JSKenyon commented 3 years ago

There are a few simple things you can try. I am assuming that A is a list? The double lookup A[d][i, k] might be slow. You could check this by trying the specific case where len(A) is one and you omit the first lookup.

You can also lift the N//Gd out into its own variable. You keep recomputing it unnecessarily.

I would also try replacing flatten with ravel. I am not sure, but flatten might be making unnecessary copies.
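In plain NumPy the difference can be seen with `np.shares_memory`: `flatten` always returns a copy, while `ravel` returns a view when the array is already C-contiguous and only copies otherwise. A transposed array such as `Z.T` is not C-contiguous, so `ravel` has to copy there as well; numba's compiled versions may differ in detail, so treat this as a NumPy-level illustration only.

```python
import numpy as np

Z = np.arange(6.0).reshape(2, 3)  # C-contiguous 2 x 3 array

flat_copy = Z.flatten()  # always a new buffer
flat_view = Z.ravel()    # view into Z, because Z is C-contiguous
flat_T = Z.T.ravel()     # Z.T is not C-contiguous, so ravel must copy

print(np.shares_memory(flat_copy, Z))  # False
print(np.shares_memory(flat_view, Z))  # True
print(np.shares_memory(flat_T, Z))     # False
print(flat_T)                          # [0. 3. 1. 4. 2. 5.]
```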

JSKenyon commented 3 years ago

Could you try this?

```python
import numpy as np

def kron_matvec(A, b):
    D = len(A)
    N = b.size
    x = b
    for d in range(D):
        Gd = A[d].shape[0]
        NGd = N//Gd          # computed once per dimension
        X = np.reshape(x, (Gd, NGd))
        Z = np.zeros((Gd, NGd), dtype=A[0].dtype)
        Ad = A[d]            # single lookup, hoisted out of the loops
        for i in range(Gd):
            for j in range(NGd):
                for k in range(Gd):
                    Z[i, j] += Ad[i, k] * X[k, j]
        x = Z.T.ravel()
    return x
```

landmanbester commented 3 years ago

Cool, thanks. Let me give that a try

landmanbester commented 3 years ago

Good call. On an 8x1024x1024 cube your version takes 150 s as opposed to 230 s. It is still a bit slow, though. If I drop the correlation in frequency, then I can parallelise over the frequency axis efficiently, because then I have

$$A = I \otimes A_x \otimes A_y$$

where $I$ is the identity. That should be pretty easy to do. Do you think we can parallelise it further efficiently, or is it too simple an operation? Maybe prange over i?
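With A = I ⊗ Ax ⊗ Ay, the operator acts on each frequency slice independently: reshaping b into an (nv, nx, ny) cube, each slice B[v] maps to Ax B[v] Ay^T. Those slices are the natural unit to hand out to threads (for example with numba's prange over v). A plain NumPy sketch of the slice-wise form, checked against the dense Kronecker product on a toy-sized cube; `kron_matvec_freq_diag`, the sizes and the random matrices are all hypothetical.

```python
import numpy as np

def kron_matvec_freq_diag(Ax, Ay, b, nv):
    """Apply (I_nv ⊗ Ax ⊗ Ay) to b without forming the Kronecker product.

    Every frequency slice is independent, so the loop over v is where a
    parallel loop would go; here it is an ordinary serial loop.
    """
    nx, ny = Ax.shape[0], Ay.shape[0]
    B = b.reshape(nv, nx, ny)
    out = np.empty_like(B)
    for v in range(nv):
        out[v] = Ax @ B[v] @ Ay.T
    return out.ravel()

rng = np.random.default_rng(1)
nv, nx, ny = 2, 3, 4
Ax = rng.standard_normal((nx, nx))
Ay = rng.standard_normal((ny, ny))
b = rng.standard_normal(nv * nx * ny)

dense = np.kron(np.eye(nv), np.kron(Ax, Ay)) @ b
print(np.allclose(kron_matvec_freq_diag(Ax, Ay, b, nv), dense))
```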

JSKenyon commented 3 years ago

I am not sure really. You can always try it and see. What are the values of D, Gd and NGd for your example problem? Just want to make sure I have the correct picture in my head.

landmanbester commented 3 years ago

D is the number of dimensions, 3 in this case. Gd will be 8, then 1024, then 1024 for this example, so NGd will be 1024x1024, then 8x1024, then 1024x8.
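Those reshapes can be traced with a little bookkeeping. The helper below is hypothetical and only computes the (Gd, N // Gd) pair used at each step of the loop for the 8 x 1024 x 1024 cube:

```python
def reshape_trace(shape):
    """Return the (Gd, N // Gd) pair used to reshape x at each step d."""
    N = 1
    for g in shape:
        N *= g
    return [(g, N // g) for g in shape]

for d, (Gd, NGd) in enumerate(reshape_trace((8, 1024, 1024))):
    print(d, Gd, NGd)
# 0 8 1048576
# 1 1024 8192
# 2 1024 8192
```

So the first pass is a small 8 x 8 matrix acting on a very wide 8 x 1048576 block, while the other two are 1024 x 1024 matrices acting on 1024 x 8192 blocks.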

JSKenyon commented 3 years ago

Is Ad[i, k] a single entry, or is it a vector?

landmanbester commented 3 years ago

Single entry

landmanbester commented 3 years ago

I think there is still something weird happening inside the function. Theoretically, the time complexity should go like

$$n_v^2 n_x n_y + n_x^2 n_y n_v + n_y^2 n_v n_x = (n_v + n_x + n_y)\, n_v n_x n_y$$

whereas the FFT goes like

$$n_v n_x n_y \log_2(n_v n_x n_y)$$

meaning that, if the function were performing optimally, it should be slower than the FFT by a factor of

$$\frac{n_v + n_x + n_y}{\log_2(n_v n_x n_y)}$$

which in this example is about 90. I see a factor closer to 1500, so there is still room for improvement.
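Plugging in the cube size from above (nv = 8, nx = ny = 1024) reproduces the estimate of roughly 90. This only evaluates the two operation-count formulas and ignores constant factors:

```python
import math

nv, nx, ny = 8, 1024, 1024

kron_ops = (nv + nx + ny) * nv * nx * ny          # one dense product per axis
fft_ops = nv * nx * ny * math.log2(nv * nx * ny)  # FFT-based product

ratio = kron_ops / fft_ops  # equals (nv + nx + ny) / log2(nv * nx * ny)
print(round(ratio, 1))  # 89.4
```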

JSKenyon commented 3 years ago

I will be back to work in earnest on Monday and can probably invest some time into this. Interested to know what is holding it back.

landmanbester commented 3 years ago

Cool, thanks. Let's pick it up then

JSKenyon commented 3 years ago

@landmanbester Do you have an example script lying around? Just to make sure I use the same inputs as you.

landmanbester commented 3 years ago

Sure, have a look at

https://github.com/ratt-ru/pfb-clean/blob/cleanup_options/pfb/test/test_kron_matvec.py

I'll probably move this functionality over to africanus once we have something performing sensibly. Thanks for taking a look

JSKenyon commented 3 years ago

You can give this a go - no promises that it is correct yet, nor is it as optimised as possible:

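```python
import numpy as np
from numba import jit

@jit(nopython=True, fastmath=True, parallel=False, cache=True, nogil=True)
def kron_matvec(A, b):
    D = len(A)
    N = b.size
    x = b

    for d in range(D):
        Gd = A[d].shape[0]
        NGd = N//Gd
        X = np.reshape(x, (Gd, NGd))
        # let np.dot do the matrix-matrix product, then transpose
        Z = A[d].dot(X).T
        # x is still b here, so this writes into the caller's array
        x[:] = Z.ravel()
    return x
```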

It seems to agree with explicitly forming the Kronecker products first, but you are better placed to check.
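One way to do that check is to compare against the dense Kronecker product on a problem small enough to form explicitly. The sketch below is a plain NumPy transcription of the same loop (no numba, and it rebinds x rather than writing into b, so the input is left untouched); `kron_matvec_np`, the sizes and the random matrices are made up for the test.

```python
import numpy as np
from functools import reduce

def kron_matvec_np(A, b):
    """Compute (A[0] ⊗ A[1] ⊗ ... ⊗ A[D-1]) @ b one factor at a time."""
    N = b.size
    x = b
    for Ad in A:
        Gd = Ad.shape[0]
        X = x.reshape(Gd, N // Gd)
        x = (Ad @ X).T.ravel()  # new array on every pass; b is never modified
    return x

rng = np.random.default_rng(2)
A = tuple(rng.standard_normal((g, g)) for g in (2, 3, 4))
b = rng.standard_normal(2 * 3 * 4)
b_before = b.copy()

dense = reduce(np.kron, A) @ b  # A[0] ⊗ A[1] ⊗ A[2], formed explicitly
print(np.allclose(kron_matvec_np(A, b), dense))
print(np.array_equal(b, b_before))
```

Each pass applies one factor along the leading axis and then transposes, which rotates the axes by one position; after D passes every factor has been applied and the axes are back in their original order.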

landmanbester commented 3 years ago

Indeed! Thanks @JSKenyon. This is much faster and also closer to how I had originally coded it up. I ended up writing the loops out explicitly because I thought numba would prefer that. Strange that it makes it so much slower. I think part of the acceleration comes from the dot being parallelised (on my machine at least). Note that I had to remove the [:] from x[:] to get the right result.
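A likely reason for that (an assumption, since the thread does not spell it out): with x = b, the slice assignment x[:] = ... writes straight into the caller's array, so b no longer holds the original input after the call, and any later comparison that reuses b sees modified data. A small NumPy illustration, with a hypothetical doubling step standing in for the matrix product:

```python
import numpy as np

def scale_inplace(b):
    x = b
    x[:] = 2 * x  # slice assignment: overwrites the caller's array
    return x

def scale_rebind(b):
    x = b
    x = 2 * x     # rebinding: x now names a new array, b is untouched
    return x

b = np.array([1.0, 2.0, 3.0])
scale_rebind(b)
print(b)  # [1. 2. 3.]
scale_inplace(b)
print(b)  # [2. 4. 6.]
```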

landmanbester commented 3 years ago

I think this is good enough for my immediate purposes, but I'll leave this one open because I would like to understand why the explicitly coded matrix-matrix product is so much slower. It might also not be sufficient to rely on numpy's internal parallelisation for A[d].dot(X).T, because I've seen this be very architecture dependent.
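If the BLAS threading does turn out to be unreliable, one option is to split the work explicitly. A rough sketch, assuming NumPy's matrix product releases the GIL while it runs (which it generally does for float arrays): the columns of X are split into chunks and each chunk is multiplied in a thread pool. `threaded_matmul` and `n_chunks` are hypothetical, and the BLAS library may still spawn its own threads inside each call.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def threaded_matmul(Ad, X, n_chunks=4):
    """Compute Ad @ X by splitting the columns of X across a thread pool."""
    Z = np.empty((Ad.shape[0], X.shape[1]), dtype=np.result_type(Ad, X))
    bounds = np.linspace(0, X.shape[1], n_chunks + 1).astype(int)

    def work(lo, hi):
        # Each task writes a disjoint block of columns, so no locking is needed.
        Z[:, lo:hi] = Ad @ X[:, lo:hi]

    with ThreadPoolExecutor(max_workers=n_chunks) as pool:
        futures = [pool.submit(work, lo, hi)
                   for lo, hi in zip(bounds[:-1], bounds[1:])]
        for f in futures:
            f.result()  # re-raise any exception from a worker
    return Z

rng = np.random.default_rng(3)
Ad = rng.standard_normal((16, 16))
X = rng.standard_normal((16, 1000))
print(np.allclose(threaded_matmul(Ad, X), Ad @ X))
```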

JSKenyon commented 3 years ago

This is an alternative which is a bit slower but not as slow as the original and makes use of the explicit matrix multiply:

```python
import numpy as np
from numba import jit

@jit(nopython=True, fastmath=True, parallel=False, cache=True, nogil=True)
def kron_matvec2(A, b):
    D = len(A)
    N = b.size
    x = b

    for d in range(D):
        Gd = A[d].shape[0]
        NGd = N//Gd
        X = np.reshape(x, (Gd, NGd))
        Z = np.zeros((Gd, NGd), dtype=A[0].dtype)
        Ad = A[d]
        # Z = Ad @ X in i-j-k order: the innermost loop walks
        # contiguous rows of X and Z
        for i in range(Gd):
            for j in range(Gd):
                for k in range(NGd):
                    Z[i, k] += Ad[i, j] * X[j, k]
        # rebind x rather than writing into b
        x = Z.T.ravel()
    return x
```

JSKenyon commented 3 years ago

I think that the discrepancy stems from the use of vectorised operations in np.dot. My experiments seem to show that the np.dot version has only 25% of the instructions present in the hand-coded case. I would need to dig into the compiled code to be sure. Shelving that idea for now, but this has been quite informative.

landmanbester commented 3 years ago

Incorporated in https://github.com/ratt-ru/pfb-clean/pull/49