v0lta / Jax-Wavelet-Toolbox

Differentiable and GPU-enabled fast wavelet transforms in JAX.
European Union Public License 1.2

Stationary 3d wavelet transform #3

Closed jakubMitura14 closed 1 year ago

jakubMitura14 commented 1 year ago

Hello, I have a JAX pipeline in which I need a 3D stationary wavelet transform and its deterministic inverse. As far as I can see, this is not yet implemented in JAX. Do you plan to extend your package with this feature? I could contribute to the effort if you do.
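To make the request concrete, here is a minimal sketch of a single-level, separable 3D stationary transform and its inverse. It is an illustration only and not the jaxwt API: it assumes the Haar wavelet, periodic boundaries, and an averaging normalization (not the orthonormal one other libraries may use), and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def haar_swt_axis(x, axis):
    """Undecimated Haar step along one axis (periodic boundary)."""
    shifted = jnp.roll(x, -1, axis=axis)  # shifted[n] = x[n + 1]
    return (x + shifted) / 2.0, (x - shifted) / 2.0


def haar_iswt_axis(approx, detail, axis):
    """Invert haar_swt_axis by averaging the two redundant reconstructions."""
    first = approx + detail                            # x[n] from pair n
    second = jnp.roll(approx - detail, 1, axis=axis)   # x[n] from pair n - 1
    return (first + second) / 2.0


def haar_swt3(x):
    """Return 8 subbands keyed by 'a'/'d' per axis, e.g. 'aad'."""
    bands = {"": x}
    for axis in range(3):
        new_bands = {}
        for key, band in bands.items():
            approx, detail = haar_swt_axis(band, axis)
            new_bands[key + "a"] = approx
            new_bands[key + "d"] = detail
        bands = new_bands
    return bands


def haar_iswt3(bands):
    """Merge the subbands axis by axis, in reverse order."""
    for axis in (2, 1, 0):
        prefixes = sorted({key[:-1] for key in bands})
        bands = {
            p: haar_iswt_axis(bands[p + "a"], bands[p + "d"], axis)
            for p in prefixes
        }
    return bands[""]


volume = jax.random.normal(jax.random.PRNGKey(0), (4, 5, 6))
bands = haar_swt3(volume)
restored = haar_iswt3(bands)
```

Every subband keeps the shape of the input, which is what distinguishes the stationary transform from the decimated one. Both functions use only `jax.numpy` operations, so they can be differentiated with `jax.grad`.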

v0lta commented 1 year ago

Sounds like a cool project. Let's do it for v0.0.8!

v0lta commented 1 year ago

Development happens in https://github.com/v0lta/Jax-Wavelet-Toolbox/pull/4. I am happy to review pull requests. I have to teach next month, so I will likely be looking into this around late April.

jakubMitura14 commented 1 year ago

Fantastic, thanks! I will probably have some questions before then, but I do not expect answers until you are free.

jakubMitura14 commented 1 year ago

Hello, I understand that you may have no time for this at the moment; I just want to report on my progress:

My primary goal is to put the transform into an optimization loop (which is why I need differentiability), so that the algorithm learns the coefficients and generates a texture from them. Since texture has repeated patterns (the histologic organization of the tissue), it should in principle be possible to describe it with a relatively small number of wavelets. Because tissue organization is invariant to translation, I first looked into the stationary wavelet transform, but the dual-tree complex wavelet transform (DTCWT) is even better: it gives approximate translation invariance and is not invariant to direction (exactly like tissue). Hence I started implementing it.
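As an illustration of what "learning the coefficients" through a differentiable inverse transform could look like, here is a toy 1D sketch. It is not the DTCWT and not code from any repository mentioned here: the inverse is an undecimated Haar step with periodic boundaries, the target signal stands in for a texture, and the names `haar_iswt`, `loss_fn`, `step` and `LEARNING_RATE` are assumptions made for this example.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.25  # hypothetical value chosen for this toy problem


def haar_iswt(approx, detail):
    """Inverse of a single-level undecimated Haar step (periodic boundary)."""
    first = approx + detail
    second = jnp.roll(approx - detail, 1)
    return (first + second) / 2.0


def loss_fn(coeffs, target):
    """Squared error between the synthesized signal and the target."""
    approx, detail = coeffs
    return jnp.sum((haar_iswt(approx, detail) - target) ** 2)


@jax.jit
def step(coeffs, target):
    """One plain gradient-descent update on the coefficients."""
    grads = jax.grad(loss_fn)(coeffs, target)
    return jax.tree_util.tree_map(
        lambda c, g: c - LEARNING_RATE * g, coeffs, grads
    )


target = jnp.sin(jnp.linspace(0.0, 2.0 * jnp.pi, 16))
coeffs = (jnp.zeros(16), jnp.zeros(16))  # start from empty coefficients
for _ in range(50):
    coeffs = step(coeffs, target)
final_loss = loss_fn(coeffs, target)
```

The gradient flows through the inverse transform into the coefficients, so the same loop structure would apply to any other inverse transform written in `jax.numpy`, with a texture loss in place of the squared error.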

I created a separate repository [1]. It has a fat Dockerfile and a devcontainer for VS Code containers, but basically it requires only JAX, NumPy and einops. The repository is based on [2]. There is a very simple test, jax_dtctw.py, based on the original code; it should give something like 4.4.

What is done:

- I translated all the original code from NumPy to JAX. As far as limited tests show, it gives results similar to the original.
- I started some optimizations and readability improvements, for example adding einops.
- I started making the algorithm differentiable, for example replacing fmod with a differentiable version.

What is not done:

- Although I started working on jit, it does not work yet because of problems with non-static array shapes.
- Many optimizations are possible. For example, I found a DTCWT in PyTorch [3] with a lot of improvements (however, it is only 2D and not as easy to port to JAX as the previous version).
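Shape problems under jit typically appear when an output shape, or the number of outputs, depends on an argument that JAX traces as an abstract value. One common remedy is to mark such arguments as static, so that each distinct value gets its own compilation. The sketch below shows this on a hypothetical toy function; whether the same fix applies to the code in [1] is an assumption.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="levels")
def averaging_pyramid(signal, levels):
    """Halve the signal `levels` times by averaging neighbouring pairs.

    `levels` controls a Python loop and the output shapes, so it must be
    a concrete Python int at trace time; hence it is declared static.
    The signal length is assumed to be divisible by 2 ** levels.
    """
    outputs = []
    for _ in range(levels):
        signal = (signal[0::2] + signal[1::2]) / 2.0
        outputs.append(signal)
    return outputs


pyramid = averaging_pyramid(jnp.arange(16.0), levels=3)
shapes = [level.shape for level in pyramid]
```

Without `static_argnames`, the call fails because `range` cannot consume a traced value. The trade-off is that the function is recompiled for every new value of `levels`.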

In my use case, I need to generate an image bigger than the original sample (texture synthesis). I suppose this should be simple with wavelets, but I do not know how to do it.

[1] https://github.com/jakubMitura14/dtcwt_jax
[2] https://github.com/rjw57/dtcwt/tree/master/dtcwt/numpy
[3] https://github.com/fbcotter/pytorch_wavelets/blob/master/pytorch_wavelets/dtcwt/coeffs.py

v0lta commented 1 year ago

Dear @jakubMitura14, thanks for letting me know. I will take a look at your code next month. Since we don't have a DTCWT in JAX yet, your project is very interesting.

jakubMitura14 commented 1 year ago

Fantastic!

v0lta commented 1 year ago

https://github.com/v0lta/Jax-Wavelet-Toolbox/pull/7 adds 1D SWT support (https://jax-wavelet-toolbox.readthedocs.io/en/latest/jaxwt.html#module-jaxwt.stationary_transform). Future releases will add more.

jakubMitura14 commented 1 year ago

Fantastic!

jakubMitura14 commented 1 year ago

Thanks