jakubMitura14 closed this issue 1 year ago.
Sounds like a cool project. Let's do it for v0.0.8!
Development happens in https://github.com/v0lta/Jax-Wavelet-Toolbox/pull/4. I am happy to review pull requests. I have to teach next month, so I will likely look into this around late April.
Fantastic, thanks! I will probably have some questions earlier, but I don't expect them to be answered before you are free.
Hello, I understand that you may not have time for this right now; I just want to share my progress.
My primary goal is to put it into an optimization loop (that is why I need differentiability), so that the algorithm learns the coefficients and generates a texture from them. Because the texture has repeated patterns (the histological organization of the tissue), it should in principle be possible to describe it with a relatively small number of wavelets. As tissue organization is translation invariant, I first looked into the stationary wavelet transform, but the dual-tree complex wavelet transform is even better: it is approximately translation invariant and it is sensitive to direction (exactly like tissue). Hence I started implementing it.
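To make "learning coefficients through a differentiable transform" concrete, here is a minimal sketch. It swaps in a single-level orthonormal 2D Haar inverse transform instead of the DT-CWT, and simply fits the coefficients to a random target image with plain gradient descent plus an L1 penalty (a stand-in for "a relatively small number of wavelets"). The names `ihaar2d`, `loss_fn`, `step`, `LR` and `SPARSITY` are hypothetical and do not come from any of the repositories mentioned in this thread.

```python
import jax
import jax.numpy as jnp

LR = 0.1         # gradient-descent step size (illustrative value)
SPARSITY = 0.01  # weight of the L1 penalty on the coefficients (illustrative value)


def ihaar2d(coeffs):
    """Inverse single-level orthonormal 2D Haar transform (stand-in for the DT-CWT).

    coeffs = (ll, d_col, d_row, d_diag), each of shape (h, w);
    returns an image of shape (2h, 2w).
    """
    ll, d_col, d_row, d_diag = coeffs
    # Each 2x2 output block is rebuilt from one coefficient of every subband.
    x00 = 0.5 * (ll + d_col + d_row + d_diag)
    x01 = 0.5 * (ll - d_col + d_row - d_diag)
    x10 = 0.5 * (ll + d_col - d_row - d_diag)
    x11 = 0.5 * (ll - d_col - d_row + d_diag)
    h, w = ll.shape
    img = jnp.zeros((2 * h, 2 * w), dtype=ll.dtype)
    img = img.at[0::2, 0::2].set(x00)
    img = img.at[0::2, 1::2].set(x01)
    img = img.at[1::2, 0::2].set(x10)
    img = img.at[1::2, 1::2].set(x11)
    return img


def loss_fn(coeffs, target):
    # Reconstruction error plus an L1 term that favours few active wavelets.
    recon = ihaar2d(coeffs)
    l1 = sum(jnp.sum(jnp.abs(c)) for c in coeffs)
    return jnp.sum((recon - target) ** 2) + SPARSITY * l1


@jax.jit
def step(coeffs, target):
    # Gradients flow through the inverse transform back to the coefficients.
    loss, grads = jax.value_and_grad(loss_fn)(coeffs, target)
    new_coeffs = tuple(c - LR * g for c, g in zip(coeffs, grads))
    return new_coeffs, loss


# Usage: fit 4 subbands of shape (8, 8) to a random 16x16 "texture".
target = jax.random.normal(jax.random.PRNGKey(0), (16, 16))
coeffs = tuple(jnp.zeros((8, 8)) for _ in range(4))
losses = []
for _ in range(100):
    coeffs, loss = step(coeffs, target)
    losses.append(float(loss))
```

Because the inverse transform is written entirely with `jax.numpy` operations, `jax.value_and_grad` differentiates the loss with respect to the coefficients directly. A DT-CWT inverse written the same way should, in principle, drop into the same loop.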
I created a separate repository [1]. It has a fat Dockerfile and a devcontainer for VS Code containers, but basically it only requires JAX, NumPy and einops. The repository is based on [2]. The very simple test jax_dtctw.py, based on the original code, should give something like 4.4.

What is done:

- I translated all of the original code from NumPy to JAX; as far as my limited tests show, it gives results similar to the original.
- I started some optimizations and readability improvements, for example using einops.
- I started working on making the algorithm differentiable, for example replacing fmod with a differentiable version.

What is not done:

- Although I started working on jit, it does not work yet because of problems with non-static array shapes.
- Many optimizations are possible. For example, I found a DTCWT implementation in PyTorch [3] with many improvements (however, it is only 2D and not as easy to port to JAX as the previous version).
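One common way around non-static shape errors under `jax.jit` is to note that array shapes are plain Python integers at trace time, so index arrays derived from them (such as the symmetric-extension indices that the NumPy dtcwt code builds with fmod-style arithmetic) can be computed with ordinary NumPy and used as constants. The sketch below assumes half-sample symmetric extension (the same as `np.pad(..., mode="symmetric")`) and an odd-length filter; `reflect_indices` and `colfilter` are hypothetical names, not the dtcwt_jax repository's code.

```python
import numpy as np
import jax
import jax.numpy as jnp


def reflect_indices(n, pad):
    """Indices for half-sample symmetric extension of a length-n axis.

    Computed with NumPy from the (static) length, so nothing here is traced
    and no gradient is needed through the index arithmetic.
    """
    idx = np.arange(-pad, n + pad)
    period = 2 * n
    idx = np.mod(idx, period)                        # wrap into [0, 2n)
    idx = np.where(idx >= n, period - 1 - idx, idx)  # mirror the upper half
    return idx


def colfilter(x, h):
    """Filter x along axis 0 with odd-length filter h; output has x's shape.

    Computes a correlation (equal to convolution for symmetric filters).
    """
    m = h.shape[0]
    if m % 2 != 1:
        raise ValueError("this sketch assumes an odd-length filter")
    pad = m // 2
    n = x.shape[0]                        # a Python int, even under jit
    idx = reflect_indices(n, pad)         # static NumPy index array
    xe = jnp.take(x, idx, axis=0)         # symmetric extension, shape (n + 2*pad, ...)
    return sum(h[k] * xe[k:k + n] for k in range(m))


colfilter_jit = jax.jit(colfilter)

# Usage: smooth the columns of a 10x3 array with a 3-tap filter.
x = np.random.default_rng(0).standard_normal((10, 3)).astype(np.float32)
h = np.array([0.25, 0.5, 0.25], dtype=np.float32)
out = colfilter_jit(jnp.asarray(x), jnp.asarray(h))
```

Gradients still flow to `x` through `jnp.take`; only the integer index computation is moved out of the traced program.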
In my use case, I need to generate an image larger than the original sample (texture synthesis). I suspect this should be straightforward with wavelets, but I do not know how to do it.
1) https://github.com/jakubMitura14/dtcwt_jax
2) https://github.com/rjw57/dtcwt/tree/master/dtcwt/numpy
3) https://github.com/fbcotter/pytorch_wavelets/blob/master/pytorch_wavelets/dtcwt/coeffs.py
Dear @jakubMitura14, thanks for letting me know. I will take a look at your code next month. Since we don't have a DTCWT in JAX yet, your project is very interesting.
Fantastic!
https://github.com/v0lta/Jax-Wavelet-Toolbox/pull/7 adds 1D SWT support (https://jax-wavelet-toolbox.readthedocs.io/en/latest/jaxwt.html#module-jaxwt.stationary_transform). Future releases will add more.
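For readers new to stationary transforms, the following is a minimal standalone illustration of what a 1D SWT computes; it is not the toolbox's implementation and does not use its API. It assumes the Haar wavelet and periodic boundaries, and uses the "à trous" scheme: no downsampling, and the filter taps are spread 2**j samples apart at level j. `haar_swt` and `haar_iswt` are hypothetical names.

```python
import jax
import jax.numpy as jnp

SQRT2 = jnp.sqrt(2.0)


def haar_swt(x, level):
    """Undecimated (stationary) Haar transform with periodic boundaries.

    Returns (approx, details); every array has the same length as x.
    """
    details = []
    approx = x
    for j in range(level):
        step = 2 ** j                      # a trous: dilate the filter by 2**j
        shifted = jnp.roll(approx, -step)  # shifted[n] = approx[n + step]
        details.append((approx - shifted) / SQRT2)
        approx = (approx + shifted) / SQRT2
    return approx, details


def haar_iswt(approx, details):
    """Inverse of haar_swt: averages the two redundant estimates per level."""
    for j in reversed(range(len(details))):
        step = 2 ** j
        d = details[j]
        est_here = (approx + d) / SQRT2                  # x[n] from a[n], d[n]
        est_prev = jnp.roll((approx - d) / SQRT2, step)  # x[n] from a[n-step], d[n-step]
        approx = 0.5 * (est_here + est_prev)
    return approx


# Usage: three levels on a length-32 signal, then a shifted copy of it.
x = jax.random.normal(jax.random.PRNGKey(1), (32,))
approx, details = haar_swt(x, level=3)
recon = haar_iswt(approx, details)
approx_s, details_s = haar_swt(jnp.roll(x, 5), level=3)
```

Because every step is a circular shift plus elementwise arithmetic, shifting the input simply shifts every coefficient array by the same amount. That shift equivariance is the property that makes the SWT attractive for translation-invariant textures.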
Fantastic!
Thanks
Hello, I have a JAX pipeline in which I need a 3D stationary wavelet transform and its deterministic inverse. As far as I can see, it is not yet implemented in JAX. Do you plan to extend your package with such a feature? I could contribute to the effort if you do.
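As a starting point for such a contribution, a 3D stationary transform can be built separably: apply a 1D undecimated analysis step along each of the three axes in turn, which yields 8 subbands that all keep the input's shape. The sketch below is a single level only and assumes the Haar wavelet with periodic boundaries; `haar_swt3`, `haar_iswt3` and the band-key scheme are hypothetical, not an existing jaxwt API. The inverse is deterministic: it averages the two redundant reconstructions available at every sample.

```python
import jax
import jax.numpy as jnp

SQRT2 = jnp.sqrt(2.0)


def _analysis(x, axis):
    # One undecimated Haar step along `axis`: shifted[n] = x[n + 1] (periodic).
    shifted = jnp.roll(x, -1, axis=axis)
    return (x + shifted) / SQRT2, (x - shifted) / SQRT2


def _synthesis(a, d, axis):
    # Exact inverse of _analysis: average the two redundant estimates of x[n].
    here = (a + d) / SQRT2
    prev = jnp.roll((a - d) / SQRT2, 1, axis=axis)
    return 0.5 * (here + prev)


def haar_swt3(x):
    """One level of a separable 3D stationary Haar transform.

    Returns a dict keyed by strings such as 'aad' (one letter per axis:
    'a' = lowpass, 'd' = highpass); every band has the shape of x.
    """
    bands = {"": x}
    for axis in range(3):
        new = {}
        for key, band in bands.items():
            a, d = _analysis(band, axis)
            new[key + "a"] = a
            new[key + "d"] = d
        bands = new
    return bands


def haar_iswt3(bands):
    """Invert haar_swt3 by undoing the axes in reverse order."""
    for axis in reversed(range(3)):
        prefixes = {k[:-1] for k in bands}  # drop the letter for this axis
        bands = {p: _synthesis(bands[p + "a"], bands[p + "d"], axis) for p in prefixes}
    return bands[""]


# Usage: transform and reconstruct a random volume, both under jit.
x = jax.random.normal(jax.random.PRNGKey(2), (4, 6, 8))
bands = jax.jit(haar_swt3)(x)
recon = jax.jit(haar_iswt3)(bands)
```

For additional levels, the same steps would be repeated on the 'aaa' band with the roll distance dilated to 2**j, as in the 1D à trous scheme.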