ratt-ru / pfb-imaging

Preconditioned forward/backward clean algorithm
MIT License

Parallelise wavelet decompositions #2

Closed: landmanbester closed this issue 1 year ago

landmanbester commented 4 years ago

Currently the wavelet decompositions take about as long as the major cycles, and actually slightly longer when imaging only one of the ESO137 data sets. The basic structure of the decomposition is as follows:

image_out = zeros_like(image)
for b in range(nbasis):
    for c in range(nchannel):
        decompose image[c, :, :] into wavelet basis b
    threshold the coefficients averaged over frequency
    for c in range(nchannel):
        image_out[c, :, :] += reconstruct channel c from the thresholded wavelet coefficients

In principle, everything is embarrassingly parallel: you just need a reduction over the frequency axis after the decomposition and then over the wavelet bases at the end. Suggestions for an efficient way to implement this are welcome (hint hint @sjperkins). My current serial implementation is here, with the frequency axis taken care of in the dot and hdot methods of the psi operator here. Note that these are calls to cythonized functions, but I am not sure if the gil is fully released during calls to waverec2 and wavedec2.
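A minimal runnable sketch of the serial structure above, in NumPy. Everything here is an assumption for illustration: a single-level orthonormal Haar transform (`haar2`/`ihaar2`) stands in for `pywt.wavedec2`/`pywt.waverec2`, the hypothetical `BASES` table holds only the Dirac basis (`'self'`) and Haar, and "thresholding" is taken to mean a hard threshold on the magnitude of the frequency-averaged coefficients. The per-channel decompositions and reconstructions are the embarrassingly parallel parts; `alpha.mean(axis=0)` is the reduction over frequency and `image_out[c] += ...` is the reduction over bases.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)

def haar2(x):
    """Single-level orthonormal 2D Haar transform (stand-in for pywt.wavedec2)."""
    a, d = (x[0::2] + x[1::2]) / SQRT2, (x[0::2] - x[1::2]) / SQRT2
    y = np.concatenate([a, d], axis=0)
    a, d = (y[:, 0::2] + y[:, 1::2]) / SQRT2, (y[:, 0::2] - y[:, 1::2]) / SQRT2
    return np.concatenate([a, d], axis=1)

def ihaar2(c):
    """Inverse of haar2 (stand-in for pywt.waverec2)."""
    nx, ny = c.shape
    y = np.empty_like(c)
    a, d = c[:, : ny // 2], c[:, ny // 2 :]
    y[:, 0::2], y[:, 1::2] = (a + d) / SQRT2, (a - d) / SQRT2
    x = np.empty_like(c)
    a, d = y[: nx // 2], y[nx // 2 :]
    x[0::2], x[1::2] = (a + d) / SQRT2, (a - d) / SQRT2
    return x

# Hypothetical basis table: 'self' is the Dirac (identity) basis.
BASES = {
    "self": (lambda x: x.copy(), lambda c: c.copy()),
    "haar": (haar2, ihaar2),
}

def threshold_and_reconstruct(image, sigma):
    """image has shape (nchannel, nx, ny) with nx, ny even."""
    nchannel = image.shape[0]
    image_out = np.zeros_like(image)
    for name, (dec, rec) in BASES.items():
        # decompose every channel (independent, so parallelisable)
        alpha = np.stack([dec(image[c]) for c in range(nchannel)])
        # reduction over frequency: keep coefficients whose channel-averaged
        # magnitude exceeds sigma (assumed hard-threshold rule)
        mask = np.abs(alpha.mean(axis=0)) > sigma
        for c in range(nchannel):
            # reconstruct channel c and accumulate over bases
            image_out[c] += rec(alpha[c] * mask)
    return image_out
```

For generic (e.g. random) data and `sigma = 0`, every coefficient survives, so each channel is reconstructed once per basis and the output equals `len(BASES) * image`.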

sjperkins commented 4 years ago

Note these are calls to cythonized functions but I am not sure if the gil is fully released during calls to waverec2 and wavedec2.

It looks like the GIL does get released, so you should be able to obtain thread parallelism if there's enough compute to perform while the GIL is released:
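If the per-channel work releases the GIL, a thread pool over the channel axis is the simplest way to exploit it. A minimal sketch using the standard library's `concurrent.futures.ThreadPoolExecutor`; `haar2` is again a hypothetical NumPy stand-in for `pywt.wavedec2`, and `nthreads=4` is an arbitrary default. Whether this gives a real speed-up depends on how much work each call does while the GIL is released.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

SQRT2 = np.sqrt(2.0)

def haar2(x):
    """Single-level orthonormal 2D Haar transform (stand-in for pywt.wavedec2)."""
    a, d = (x[0::2] + x[1::2]) / SQRT2, (x[0::2] - x[1::2]) / SQRT2
    y = np.concatenate([a, d], axis=0)
    a, d = (y[:, 0::2] + y[:, 1::2]) / SQRT2, (y[:, 0::2] - y[:, 1::2]) / SQRT2
    return np.concatenate([a, d], axis=1)

def decompose_channels(image, nthreads=4):
    """Decompose each channel image[c] in a worker thread.

    Threads only help if the per-channel work releases the GIL; many NumPy
    operations release it inside their inner loops.
    """
    with ThreadPoolExecutor(max_workers=nthreads) as pool:
        # map preserves channel order, so the results stack back
        # into an array of shape (nchannel, nx, ny)
        return np.stack(list(pool.map(haar2, image)))
```

The threaded result should match the serial `np.stack([haar2(ch) for ch in image])` exactly, which makes it easy to check correctness before timing anything.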

https://github.com/PyWavelets/pywt/search?q=nogil&unscoped_q=nogil

I'll think some more about the other parallelism concerns.

landmanbester commented 4 years ago

I finally got around to testing this properly and I am running into a few issues. Firstly, I think I overcomplicated our lives a bit @sjperkins. Something which might not have been clear is that the wavelet coefficients are not always the same size as the decomposed image (try with nx = ny = 250 for example). Since the coefficients for the Dirac basis are always the same size as the image, this meant that I couldn't simply keep the coefficients for all bases in a single array, unless I used an object array (as I did for weights_21). I think that was the main reason why I didn't include the loop over basis inside the dot and hdot functions. However, I realised that I can simply pad the array to the correct shape for the Dirac basis (i.e. when basis=='self'). I have refactored the functions to include the loop over basis on the test_new_prox branch. I think they now have a much cleaner interface.
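A minimal NumPy sketch of the padding idea, assuming each basis produces a flat coefficient vector per band. The names `pack_coeffs` and `unpack_coeffs` are hypothetical. Shorter vectors (such as the Dirac basis, whose length is `nx * ny`) are zero-padded to the longest length `ntot`, giving a single `[nbasis, nband, ntot]` array instead of an object array.

```python
import numpy as np

def pack_coeffs(coeff_list):
    """Stack per-basis coefficients of different lengths into one array.

    coeff_list[b] has shape (nband, n_b). Entries shorter than
    ntot = max(n_b) are zero-padded at the end.
    Returns alpha with shape (nbasis, nband, ntot) and the true lengths.
    """
    lengths = [c.shape[-1] for c in coeff_list]
    ntot = max(lengths)
    nband = coeff_list[0].shape[0]
    alpha = np.zeros((len(coeff_list), nband, ntot), dtype=coeff_list[0].dtype)
    for b, c in enumerate(coeff_list):
        alpha[b, :, : c.shape[-1]] = c
    return alpha, lengths

def unpack_coeffs(alpha, lengths):
    """Inverse of pack_coeffs: strip the padding again."""
    return [alpha[b, :, :n] for b, n in enumerate(lengths)]
```

Because the padding is zero, it does not change norms or sums over the coefficient axis, so reductions over bands or bases can run on the packed array directly.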

The second thing I noticed was that hdot and dot were swapped around in the prox_21 function when using dask_psi. I have no idea how I was getting comparable results. I think there was some weirdness going on because the alphas were not of the correct shape (i.e. [nbasis, nband, ntot], where ntot is the size of the wavelet coefficients corresponding to a single image). In any case, comparing the prox_21 results I got a maximum absolute difference of only 1e-5, which might explain why the results looked about the same. I have left the prox_21 function for later, so don't worry about that.
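One cheap guard against swapped or mis-shaped `dot`/`hdot` implementations is the standard dot-product (adjoint) test: for random `x` and `y`, `<dot(x), y>` must equal `<x, hdot(y)>` to round-off. The sketch below uses a hypothetical dense operator mapping `alpha` of shape `[nbasis, nband, ntot]` to images of shape `[nband, npix]`; `psi_dot`, `psi_hdot` and the random matrices `W` are illustrations, not the pfb operator. Here `psi_dot` is taken to be synthesis (coefficients to image); swap the roles if your convention differs.

```python
import numpy as np

def dot_test(op, adj, x_shape, y_shape, seed=0):
    """Relative mismatch |<op(x), y> - <x, adj(y)>| / |<op(x), y>|.

    A correct adjoint gives a value at round-off level; swapped or
    mis-shaped implementations generally do not.
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(x_shape)
    y = rng.standard_normal(y_shape)
    lhs = np.vdot(op(x), y)   # <op(x), y>
    rhs = np.vdot(x, adj(y))  # <x, adj(y)>
    return abs(lhs - rhs) / abs(lhs)

# Hypothetical operator: alpha (nbasis, nband, ntot) -> image (nband, npix)
rng = np.random.default_rng(1)
nbasis, nband, ntot, npix = 2, 3, 16, 16
W = rng.standard_normal((nbasis, npix, ntot))  # one dense "basis" per b

def psi_dot(alpha):
    # synthesis: sum over bases b of W[b] @ alpha[b, band]
    return np.einsum("bpt,bnt->np", W, alpha)

def psi_hdot(image):
    # analysis (adjoint): W[b].T @ image[band] for every basis b
    return np.einsum("bpt,np->bnt", W, image)
```

Running `dot_test(psi_dot, psi_hdot, (nbasis, nband, ntot), (nband, npix))` should give a number near machine precision; replacing `psi_hdot` with something that is not the true adjoint gives a mismatch of order one.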

The main thing that I wanted to check is that I can get some sort of speed-up with the dask versions of the functions. I have implemented them and confirmed that they give the same results as the non-dask versions. When running the dask versions the cores do seem to spin up, but they are mostly red in htop (i.e. mostly kernel time) and the functions are actually slower (at least on my laptop). Things don't seem to improve when I make the image larger either. If you have some time @sjperkins, can you please have a look at what I've done there? The comparison is in the test_psi_operator script. Sorry, I probably could have gotten to this stage myself before getting you involved.

sjperkins commented 4 years ago

I agree that the refactoring makes things cleaner.

With regard to the speed-up, I'm inclined to look at the PyWavelets code. However, the PyWavelets developers do seem to think that it's possible to process multiple images in multiple threads and get some sort of speed-up. https://github.com/PyWavelets/pywt/issues/371#issuecomment-389567627

In fact there's a multi-threaded batch-processing example: https://github.com/PyWavelets/pywt/blob/master/demo/batch_processing.py that it may be worth running in order to see what sort of speed-ups are achievable.

If it's not PyWavelets, then it may be that the pure Python loop code is the bottleneck.

I could probably take a look early next week.

sjperkins commented 4 years ago

Also, would it be possible to create a PR on your dev branch in which we can examine the code, and that can serve as a springboard for PRs I wish to base on your development branch?

It doesn't have to be a big deal, it's just to make things easier.

landmanbester commented 4 years ago

Yep, of course. The dev branch is project_init. I will tidy things up a bit and merge into master soon. I just wanted to properly initialise the project, the testing framework in particular, before putting it in master. Thanks for the help.

landmanbester commented 4 years ago

@sjperkins, just to clarify, I think we should focus on parallelising just the wavelet decompositions for now (so just the psi.dot and psi.hdot functions). Don't worry about building the graph for the full prox_21 function.

sjperkins commented 3 years ago

Existing wavelet numba functionality doesn't offer any improvement over PyWavelets.

Possible reasons:

  1. Wavelet decompositions + reconstructions are inherently data intensive operations due to the need to perform 1D convolutions across each axis. This is more problematic on slowly changing axes where 1D reads and writes are required during convolution to ensure spatial locality. Thus, it may be difficult to achieve high arithmetic intensity relative to the size of the input data. @mreineck do you have any thoughts on this?

  2. This was always at the back of my mind, but the following code https://github.com/ratt-ru/pfb-clean/blob/07ec2fcd518541bf3cf586a0741c65a0db89fc94/pfb/wavelets/wavelets.py#L207-L231 is incurring a memory allocation for each 1D chunk of data it's convolving. It may be worth hoisting the allocation out of the loop and using pointer arithmetic for the copies.
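A sketch of the hoisting suggestion in plain NumPy/Python, written in the explicit-loop style of a numba kernel but without `@njit`, so it runs unmodified (slowly) under CPython. It is a simplified, hypothetical stand-in for the linked wavelet code: a full 1D convolution down each column (the slowly changing axis) with zero padding, no downsampling, and none of the PyWavelets boundary modes. The padded scratch buffer is allocated once, outside the column loop; inside the loop each strided column is only copied into it. In numba or C, that copy is where pointer arithmetic would replace the slice assignment.

```python
import numpy as np

def convolve_axis0(x, filt):
    """Full 1D convolution of every column of x with filt, along axis 0.

    Returns an array of shape (n + len(filt) - 1, m).
    """
    n, m = x.shape
    nf = filt.shape[0]
    nout = n + nf - 1
    out = np.empty((nout, m), dtype=x.dtype)
    # hoisted allocation: one zero-padded scratch buffer for all columns
    buf = np.zeros(n + 2 * (nf - 1), dtype=x.dtype)
    for j in range(m):
        # copy the strided column into the contiguous middle of the buffer;
        # the zero padding at both ends is never overwritten
        buf[nf - 1 : nf - 1 + n] = x[:, j]
        for i in range(nout):
            acc = 0.0
            for k in range(nf):
                # x[i - k] lives at buf[nf - 1 + i - k]
                acc += filt[k] * buf[i + nf - 1 - k]
            out[i, j] = acc
    return out
```

Each column of the result should agree with `np.convolve(x[:, j], filt)` (whose default mode is `'full'`), which gives a quick correctness check before adding `@njit`.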

landmanbester commented 3 years ago

I dug into this a little on a small 1k x 1k example. Here is a snakeviz view of the cProfile report from running the pywt versions:

[Screenshot: snakeviz view of the cProfile report (profile_psi)]

The highlighted core.py item is a call to from_array, which takes about half the time. I didn't realise these calls were so expensive. I guess the solution is expressing more of the computation as a graph. I'll see if I can work towards that. I'll try to repeat this experiment for the numba versions.
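For reference, a small standard-library helper for producing this kind of report. `profile_call` is a hypothetical name: it runs a function under `cProfile`, optionally writes a `.prof` file that `snakeviz <file>` can display, and returns the top entries sorted by cumulative time, the ordering in which an expensive `from_array` call would stand out. From the shell, `python -m cProfile -o out.prof script.py` produces the same kind of file.

```python
import cProfile
import io
import pstats

def profile_call(func, *args, outfile=None, top=10, **kwargs):
    """Run func(*args, **kwargs) under cProfile.

    Returns (result, report), where report is the text of the `top`
    entries sorted by cumulative time. If outfile is given, the raw stats
    are also written there for viewing with snakeviz.
    """
    prof = cProfile.Profile()
    result = prof.runcall(func, *args, **kwargs)
    if outfile is not None:
        prof.dump_stats(outfile)  # view with: snakeviz <outfile>
    stream = io.StringIO()
    pstats.Stats(prof, stream=stream).sort_stats("cumulative").print_stats(top)
    return result, stream.getvalue()
```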

sjperkins commented 3 years ago

Ah, it looks like it's hashing the underlying numpy arrays. Note that if you supply a name argument to from_array, it'll use this as the unique token instead; 'myname-' + uuid.uuid4().hex should work nicely.
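The trade-off can be illustrated without dask. In the sketch below, the hypothetical `content_name` derives a key from a SHA-1 hash of the array's bytes, which is similar in spirit to (but not the same code as) dask's default naming: identical data gets an identical key, so duplicated work can be recognised, but every byte has to be read. `random_name` follows the `'myname-' + uuid.uuid4().hex` suggestion and costs the same regardless of array size. In dask itself, the corresponding choices are the default `da.from_array(x)`, an explicit `name=` string, or `name=False` for a random name.

```python
import hashlib
import uuid

import numpy as np

def content_name(arr, prefix="myname"):
    """Deterministic key from the array contents (cost grows with array size)."""
    digest = hashlib.sha1(np.ascontiguousarray(arr).tobytes()).hexdigest()
    return f"{prefix}-{digest}"

def random_name(prefix="myname"):
    """Random key: constant cost, but identical data is no longer deduplicated."""
    return f"{prefix}-{uuid.uuid4().hex}"
```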

sjperkins commented 3 years ago

The following would also work, as long as the underlying data doesn't change (this also applies to the last comment).

da.from_array(array, name='myname-' + str(hash(id(array))))
landmanbester commented 3 years ago

Thank you!

landmanbester commented 3 years ago

The original arrays can't be mutated in this case, but setting name=False does the trick and the profile looks much better:

[Screenshot: snakeviz profile with name=False (profile_psi)]

One level up (I am only doing a single major cycle here), we see that most of the time is still spent in the primal dual:

[Screenshot: snakeviz profile one level up (profile_psi)]

Thread occupancy seems pretty low, so maybe turning the primal dual into a graph would help here.

landmanbester commented 1 year ago

Completed in https://github.com/ratt-ru/pfb-clean/pull/58