PyWavelets / pywt

PyWavelets - Wavelet Transforms in Python
http://pywavelets.readthedocs.org
MIT License

Pywavelets GPU version #593

Open: miguelcarcamov opened this issue 3 years ago

miguelcarcamov commented 3 years ago

Hi everyone,

First of all, let me say that this library is awesome. I have watched it evolve, and at this point it is great. I am currently building software for image synthesis, and I would really like to include wavelets in it. That said, this software will be heavily oriented toward large-scale computing and big data. In that sense, have you thought about adding support for Dask or Numba? Is it in your plans, or has it already been done?

Cheers

v0lta commented 3 years ago

I am working on a PyTorch wavelet toolbox that aims for full pywt compatibility (https://github.com/v0lta/PyTorch-Wavelet-Toolbox). PyTorch supports efficient GPU processing, so you should be able to use my work to move some of your wavelet computations onto a GPU.

grlee77 commented 3 years ago

@v0lta that sounds nice.

I actually have a library with a CuPy-based implementation that includes a number of additional discrete wavelet transforms (like dual-tree wavelets, etc.). I was previously using it for research purposes and haven't had time to write proper docs, etc., but I should probably just open a new repository and mark it as beta/experimental until it gets wider testing.

The standard DWT routines in that library were built on this GPU equivalent of SciPy's scipy.signal.upfirdn: https://github.com/mritools/fast_upfirdn.
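For readers unfamiliar with upfirdn: it upsamples a signal, applies an FIR filter, and downsamples the result. A one-level DWT is a pair of such filter-and-downsample operations (a low-pass and a high-pass branch), and the inverse is upsample-and-filter on each branch followed by a sum. Below is a minimal pure-Python sketch of that idea for the Haar wavelet only, assuming even-length input and no boundary padding. The helper names (upfirdn, dwt_haar, idwt_haar) are hypothetical; this is neither the fast_upfirdn API nor how PyWavelets implements its transforms internally.

```python
import math


def upfirdn(h, x, up=1, down=1):
    """Upsample x by `up`, apply FIR filter h, then downsample by `down`.

    Simplified pure-Python version of the operation scipy.signal.upfirdn
    performs on 1-D input; output lengths can differ from SciPy's by trailing
    zeros when len(h) is not a multiple of `up`.
    """
    # Upsample: insert up-1 zeros between consecutive samples.
    xu = []
    for i, v in enumerate(x):
        if i:
            xu.extend([0.0] * (up - 1))
        xu.append(v)
    # Full linear convolution with h.
    y = [0.0] * (len(xu) + len(h) - 1)
    for i, xv in enumerate(xu):
        for j, hv in enumerate(h):
            y[i + j] += xv * hv
    # Downsample: keep every `down`-th sample, starting at index 0.
    return y[::down]


# Haar filter bank, using the sign convention of pywt.Wavelet('haar').
A = 1 / math.sqrt(2)
DEC_LO, DEC_HI = [A, A], [-A, A]
REC_LO, REC_HI = [A, A], [A, -A]


def dwt_haar(x):
    """Single-level Haar DWT of an even-length list (no boundary padding)."""
    # Filter at full rate, then keep the odd-indexed outputs so that output k
    # combines the pair (x[2k], x[2k+1]). upfirdn's own down=2 would keep the
    # even-indexed outputs instead, i.e. the other polyphase component.
    return upfirdn(DEC_LO, x)[1::2], upfirdn(DEC_HI, x)[1::2]


def idwt_haar(cA, cD):
    """Inverse of dwt_haar: upsample each band by 2, filter, and add."""
    a = upfirdn(REC_LO, cA, up=2)
    d = upfirdn(REC_HI, cD, up=2)
    return [u + v for u, v in zip(a, d)]


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
cA, cD = dwt_haar(x)
print([round(v, 4) for v in cA])  # [2.1213, 4.9497, 7.7782]
print([round(v, 4) for v in cD])  # [-0.7071, -0.7071, -0.7071]
print([round(v, 10) for v in idwt_haar(cA, cD)])  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```

For even-length input these coefficients should agree with pywt.dwt(x, 'haar'); longer filters and odd lengths would need the boundary-extension modes that this sketch leaves out.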

grlee77 commented 3 years ago

Another incentive to get this repo up would be to enable a GPU-based denoise_wavelet (from the scikit-image API) in cuCIM.
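As background on what a denoise_wavelet-style routine does (decompose, shrink the detail coefficients, reconstruct), here is a minimal CPU-only, pure-Python sketch. It assumes a multi-level Haar decomposition, a noise estimate taken from the finest details (median absolute value / 0.6745), and the universal "VisuShrink" threshold applied with soft thresholding. scikit-image's denoise_wavelet offers VisuShrink as one option among others and handles many more wavelets and n-D images; the function names below are hypothetical, and this is not the scikit-image or cuCIM code.

```python
import math
import random
import statistics

SQRT2 = math.sqrt(2)


def haar_step(x):
    """One Haar analysis level on an even-length list -> (approx, detail)."""
    approx = [(x[i] + x[i + 1]) / SQRT2 for i in range(0, len(x), 2)]
    detail = [(x[i] - x[i + 1]) / SQRT2 for i in range(0, len(x), 2)]
    return approx, detail


def haar_inverse_step(approx, detail):
    """Invert haar_step by rebuilding each pair (x[2k], x[2k+1])."""
    out = []
    for a, d in zip(approx, detail):
        out.extend([(a + d) / SQRT2, (a - d) / SQRT2])
    return out


def soft_threshold(c, t):
    """Shrink c toward zero by t; values with |c| <= t become zero."""
    return math.copysign(max(abs(c) - t, 0.0), c)


def denoise_haar(x, levels):
    """Soft-threshold every detail band of a `levels`-deep Haar decomposition.

    len(x) must be divisible by 2**levels (no boundary padding is done).
    """
    approx, details = list(x), []
    for _ in range(levels):
        approx, d = haar_step(approx)
        details.append(d)  # details[0] is the finest scale
    # Noise estimate from the finest details: median(|d|) / 0.6745.
    sigma = statistics.median(abs(c) for c in details[0]) / 0.6745
    # Universal ("VisuShrink") threshold: sigma * sqrt(2 ln N).
    t = sigma * math.sqrt(2 * math.log(len(x)))
    details = [[soft_threshold(c, t) for c in d] for d in details]
    # Rebuild from the coarsest level back to full resolution.
    for d in reversed(details):
        approx = haar_inverse_step(approx, d)
    return approx


def sq_err(y, ref):
    """Sum of squared differences between two equal-length lists."""
    return sum((a - b) ** 2 for a, b in zip(y, ref))


# Usage: a step signal plus Gaussian noise (seeded for reproducibility).
rng = random.Random(0)
clean = [0.0] * 32 + [4.0] * 32
noisy = [v + rng.gauss(0.0, 0.5) for v in clean]
denoised = denoise_haar(noisy, levels=3)
print(f"noisy error: {sq_err(noisy, clean):.3f}, "
      f"denoised error: {sq_err(denoised, clean):.3f}")
```

Because this Haar transform is orthonormal and the step edge sits on a multiple of 2**levels, the clean signal has no detail coefficients here, so soft thresholding can only shrink the noise: for this particular input the denoised error is guaranteed to be lower than the noisy one.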

shailesh1729 commented 3 years ago

I am working on a JAX-based port of PyWavelets as part of the CR-Sparse library. The cr.sparse.wt module currently implements dwt, idwt, upcoef, downcoef, wavedec, waverec, dwt2, idwt2, etc. See the API here. JAX is built on XLA, so it supports just-in-time compilation of high-performance code for GPUs and TPUs. So far, I have been able to add support for all of the discrete wavelet families.
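As a rough picture of what 2-D routines such as dwt2 and idwt2 compute, here is a pure-Python sketch of one level of the separable 2-D Haar transform: apply the 1-D transform along every row, then down every column. It assumes an image with even dimensions and no boundary padding. It is not CR-Sparse's JAX code or pywt's implementation, and the helper names and the LL/LH/HL/HH labels are hypothetical choices made for this example.

```python
import math

S = 1 / math.sqrt(2)


def split(v):
    """1-D Haar step on an even-length sequence -> (low-pass, high-pass)."""
    lo = [S * (v[i] + v[i + 1]) for i in range(0, len(v), 2)]
    hi = [S * (v[i] - v[i + 1]) for i in range(0, len(v), 2)]
    return lo, hi


def merge(lo, hi):
    """Inverse of split: rebuild each sample pair from (lo[k], hi[k])."""
    out = []
    for a, d in zip(lo, hi):
        out += [S * (a + d), S * (a - d)]
    return out


def transpose(m):
    return [list(col) for col in zip(*m)]


def split_columns(m):
    """Apply split down every column of m; return (low, high) matrices."""
    lo, hi = zip(*(split(col) for col in transpose(m)))
    return transpose(lo), transpose(hi)


def merge_columns(lo, hi):
    """Inverse of split_columns."""
    return transpose([merge(a, d) for a, d in zip(transpose(lo), transpose(hi))])


def haar_dwt2(img):
    """Single-level separable 2-D Haar DWT of an image with even dimensions.

    Returns (LL, (LH, HL, HH)): the first letter names the filter applied
    along the rows, the second the filter applied down the columns.
    """
    rows = [split(r) for r in img]
    row_lo = [lo for lo, _ in rows]
    row_hi = [hi for _, hi in rows]
    ll, lh = split_columns(row_lo)
    hl, hh = split_columns(row_hi)
    return ll, (lh, hl, hh)


def haar_idwt2(ll, details):
    """Inverse of haar_dwt2: undo the column step, then the row step."""
    lh, hl, hh = details
    row_lo = merge_columns(ll, lh)
    row_hi = merge_columns(hl, hh)
    return [merge(a, d) for a, d in zip(row_lo, row_hi)]


img = [[1.0, 2.0, 3.0, 4.0],
       [5.0, 6.0, 7.0, 8.0]]
ll, (lh, hl, hh) = haar_dwt2(img)
print([[round(v, 6) for v in r] for r in ll])  # [[7.0, 11.0]]
print([[round(v, 6) for v in r] for r in lh])  # [[-4.0, -4.0]]
```

Building the 2-D transform from the 1-D one is what lets 2-D routines reuse the same 1-D kernels. pywt.dwt2 returns a similarly shaped (cA, (cH, cV, cD)) result, but this sketch does not try to match its band orientation or boundary handling.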

It's still an early-stage prototype. On the CPU, the implementation runs somewhat slower than PyWavelets, and I haven't benchmarked it on GPUs yet. The API closely resembles PyWavelets'. I have been able to port several unit tests from PyWavelets, and they run fine. The documentation is sparse so far.