frankong opened this issue 5 years ago
here is a possible approach: https://lernapparat.de/2d-wavelet-transform-pytorch/
This is addressed by PR #32!
Thanks, Sid!
Hi everyone,
I see this feature was initially added and then removed from the master branch. Would it be possible to know why?
I would be really interested to get this working. I have been working on some compressed sensing reconstructions with sigpy, keeping an eye on recon time: the wavelet transform is definitely the bottleneck of the pipeline. Is there any plan to allow the wavelet transform to run on GPU?
Thanks a lot for this amazing tool!
Marco
Hi,
I think @frankong was worried about the correctness of my implementation. I would like to revisit it, but do not have time right now. That being said, if you are able to resurrect the code and clarify that it's in "beta", I'll be happy to re-include it.
Hi Sid,
I have been looking a bit into the solution proposed here
I ran some quick tests on 2D data, comparing the forward transform of this implementation against the one in PyWavelets. Qualitatively the results look reasonably similar, although the coefficients' absolute values differ a bit. The output also has a different size; the zero padding is probably applied in a slightly different way, since the tiled coefficients appear a bit shifted.

Another thing I noticed: looking at the residual error of a forward+inverse transform, this new implementation shows a slightly higher error than the forward+inverse of PyWavelets (though the effect on the input image could not be appreciated). These were just qualitative tests; I am happy to go into more detail with the comparison if needed. Do you remember by any chance what the issue was when you first incorporated this? I can try to focus on that.
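For what it's worth, the round-trip residual check described above can be made precise. Below is a minimal sketch of the metric, using a one-level orthonormal 1D Haar transform in plain NumPy as an illustrative stand-in (the Haar functions here are my own example, not the code under test):

```python
import numpy as np

def haar_forward(x):
    """One level of an orthonormal 1D Haar transform (even-length input)."""
    s = 1.0 / np.sqrt(2.0)
    approx = s * (x[0::2] + x[1::2])   # low-pass branch, subsampled by 2
    detail = s * (x[0::2] - x[1::2])   # high-pass branch, subsampled by 2
    return approx, detail

def haar_inverse(approx, detail):
    """Inverse of haar_forward; exact for an orthonormal transform."""
    s = 1.0 / np.sqrt(2.0)
    x = np.empty(2 * approx.size)
    x[0::2] = s * (approx + detail)
    x[1::2] = s * (approx - detail)
    return x

rng = np.random.default_rng(0)
x = rng.standard_normal(256)
x_rec = haar_inverse(*haar_forward(x))

# Relative residual of forward + inverse; for an orthonormal transform
# this should sit near machine precision for float64.
residual = np.linalg.norm(x - x_rec) / np.linalg.norm(x)
print(residual)
```

Comparing this relative residual between the two implementations on the same input would turn the qualitative observation above into a number.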
Regarding performance, I had a hard time getting decent computation times. Initially, this solution was way slower than the PyWavelets implementation, although I could see some activity on the GPU. Using the NVIDIA profiler (Nsight), I could see that the GPU workload became much more intermittent right after the first iteration of GradientMethod. I tried to move as many operations as possible into the initialization of the Wavelet linear operator, but that didn't really help.

The key move was to downgrade my CUDA version from 11.4 to 11.2 (I am running this on Windows 10): all of a sudden I got a factor-of-10 reduction in recon time (and I can see that the GPU is working consistently). Now recon times are more comparable with the PyWavelets implementation. On 3D data this is now more than 50% faster, while it is roughly 50% slower than the original implementation on 2D data. I am definitely not an expert in code optimization (even less on GPU), but I believe there is some room for improvement.
Looking forward to your feedback (and thanks a lot for the effort you put into sigpy!!!)
Marco
Hi Marco,
Thanks so much for the tests! I will personally be swamped for a few months, but if you're willing to take the lead, I'd be happy to review any pull requests on this.
The following are what I envision would need to be done before the GPU version can be implemented:
- Support the `modes` argument used in pywt (details: https://pywavelets.readthedocs.io/en/latest/regression/modes.html).

On the performance difference, I think we can revisit it after (2) is done! Once that's set as a baseline, we can look at optimizations.
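As a side note on the `modes` point: PyWavelets' signal-extension modes correspond closely to NumPy's `np.pad` modes, which a GPU implementation could reuse via the equivalent CuPy function. A minimal sketch, where the mode-name mapping is my own assumption rather than existing SigPy code:

```python
import numpy as np

# Hypothetical mapping from pywt extension-mode names to np.pad modes.
# pywt's "zero" pads with zeros, "symmetric" mirrors including the edge
# sample, "reflect" mirrors excluding it, and "periodic" wraps around.
PYWT_TO_NUMPY_MODE = {
    "zero": "constant",
    "symmetric": "symmetric",
    "reflect": "reflect",
    "periodic": "wrap",
}

def extend_signal(x, pad, mode="symmetric"):
    """Extend a 1D signal by `pad` samples on each side, pywt-style."""
    return np.pad(x, pad, mode=PYWT_TO_NUMPY_MODE[mode])

x = np.array([1.0, 2.0, 3.0])
print(extend_signal(x, 2, "zero"))       # [0. 0. 1. 2. 3. 0. 0.]
print(extend_signal(x, 2, "symmetric"))  # [2. 1. 1. 2. 3. 3. 2.]
print(extend_signal(x, 2, "periodic"))   # [2. 3. 1. 2. 3. 1. 2.]
```

If the boundary handling is delegated to padding like this, the convolution itself can stay mode-agnostic.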
Thanks much for all you've reported on so far already! Please let me know your thoughts on the above, and if you're interested in pursuing this.
Hi Sid,
Yes, it sounds like a fair plan. I have had a quick look at 1) and 2), and it shouldn't be anything too hard to achieve. Not sure how much time I can dedicate to this in the coming weeks, but August looks definitely less busy. I will keep you posted.
Thanks
Is your feature request related to a problem? Please describe.
There is no wavelet transform on GPU. Currently, SigPy moves the array to CPU and uses PyWavelets to perform the wavelet transform, which is the main bottleneck for compressed sensing MRI reconstructions.
Describe the solution you'd like
A GPU wavelet transform can leverage the GPU multi-channel convolution operations already wrapped in SigPy. Low-pass and high-pass filtering can be done in one operation using the output channels. Strides can be used to incorporate subsampling.
Describe alternatives you've considered
Implement custom GPU kernels for wavelet transforms. But this would be less optimized than cuDNN convolutions.