yutian-wang closed this issue 6 months ago.
That's a good question. SWT support is not official because I never finished the iswt code. I will have to take a look and get back to you.
You are facing the problem that conv_transpose fails to undo the dilation. I think this is where the requirements of CNNs and wavelet transforms differ. If we learn the filters via gradient descent, we do not need exact inversion of the dilation. For a wavelet toolbox like this one, however, we do. https://github.com/v0lta/PyTorch-Wavelet-Toolbox/pull/73/commits/2520a9175f2de8a9911b9d5a5e92b4eba25b7192 takes care of this.
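To make the dilation point concrete, here is a minimal NumPy sketch (not the toolbox's PyTorch code) of a circular, undecimated Haar transform. It assumes the Haar wavelet and a periodic boundary; `haar_swt`, `haar_iswt` and the `dilate` flag are hypothetical names. The inverse only reconstructs the input if it reuses the analysis tap spacing `2**j` at each level; with spacing 1 everywhere, which mimics a transposed convolution that ignores the dilation, it does not.

```python
import numpy as np


def haar_swt(x, level):
    """Circular undecimated (a trous) Haar transform along the last axis.

    Nothing is downsampled; instead the two filter taps sit
    d = 2**j samples apart at level j (j = 0, 1, ...).
    """
    approx = np.asarray(x, dtype=float)
    details = []
    for j in range(level):
        d = 2 ** j
        nxt = np.roll(approx, -d, axis=-1)  # approx[n + d], wrapping around
        details.append((approx - nxt) / np.sqrt(2))
        approx = (approx + nxt) / np.sqrt(2)
    return approx, details


def haar_iswt(approx, details, dilate=True):
    """Inverse of haar_swt, coarsest level first.

    With dilate=True the synthesis reuses the analysis spacing 2**j.
    dilate=False uses spacing 1 at every level, which mimics a
    transposed convolution that does not undo the dilation.
    """
    a = np.asarray(approx, dtype=float)
    for j in reversed(range(len(details))):
        d = 2 ** j if dilate else 1
        here = (a + details[j]) / np.sqrt(2)  # estimate of x[n] from position n
        prev = np.roll((a - details[j]) / np.sqrt(2), d, axis=-1)  # from position n - d
        a = 0.5 * (here + prev)  # average the two estimates
    return a


x = np.random.default_rng(0).standard_normal((4, 32))
approx, details = haar_swt(x, level=3)
print(np.allclose(haar_iswt(approx, details), x))                # True
print(np.allclose(haar_iswt(approx, details, dilate=False), x))  # False
```

Every coefficient array keeps the input's shape, which is what separates the SWT from the decimated wavedec/waverec pair.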
The commit is not finished yet: tests for level=None aren't passing. But if you pass a level argument, it will work.
Feel free to take a look via:
pip install git+ssh://git@github.com/v0lta/PyTorch-Wavelet-Toolbox.git@improved-docs
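On the level=None point: as far as I know, PyWavelets' swt falls back to pywt.swt_max_level when no level is given, and that function is documented as the number of times the input length is evenly divisible by two. If ptwt were to adopt the same default (an assumption; nothing in this thread says it will), the rule fits in a few lines. `default_swt_level` is a hypothetical name.

```python
def default_swt_level(length):
    """Number of times length is evenly divisible by 2 (PyWavelets' swt_max_level rule)."""
    if length <= 0:
        raise ValueError("length must be positive")
    level = 0
    while length % 2 ** (level + 1) == 0:
        level += 1
    return level


print(default_swt_level(1024))  # 10
print(default_swt_level(96))    # 5
```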
Thank you for your answer, it helps me a lot!
I tested your new code and found an issue: if the batch size is greater than 1, the code raises an error. The error comes from _conv_transpose_dedilate. For example, if I pass [(3,1024), (3,1024), (3,1024)] to iswt(), the first loop iteration returns shape (1,3072), and the second iteration fails with a shape mismatch between (1,3072) and (3,1024). The first iteration should obviously return (3,1024). But I haven't fully understood your code yet, so I don't know how to fix it.
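The reported (1, 3072) shape is exactly what you get when three length-1024 rows are joined along the signal axis (3 × 1024 = 3072), so the batch dimension appears to be folded into the signal. As a NumPy sketch (not the toolbox's _conv_transpose_dedilate), a circular dilated filter that only shifts along the last axis keeps a (3, 1024) batch intact. `circular_dilated_filter` is a hypothetical helper, and the periodic boundary is an assumption.

```python
import numpy as np


def circular_dilated_filter(x, taps, dilation):
    """y[..., n] = sum_k taps[k] * x[..., (n + k * dilation) % N].

    Only the last (signal) axis is shifted, so leading batch axes pass
    through untouched: an input of shape (3, 1024) gives (3, 1024).
    """
    x = np.asarray(x, dtype=float)
    y = np.zeros_like(x)
    for k, tap in enumerate(taps):
        y += tap * np.roll(x, -k * dilation, axis=-1)
    return y


batch = np.random.default_rng(0).standard_normal((3, 1024))
out = circular_dilated_filter(batch, [0.5, 0.5], dilation=4)
print(out.shape)  # (3, 1024)
```

Because np.roll is given axis=-1, each row wraps around within itself; shifting a flattened (3072,) signal instead would leak samples from one batch element into the next.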
Ahh, yes I see, the tests also currently don't cover batched inputs. Good catch, I will look into this!
37d0d31 no longer has the batch problem. However, level-argument support still needs more work. Please let me know if this solves your problem.
Thank you very much for your contribution! I tested it, and your code is correct when batch > 1.
I've been studying this code for the past few days: https://github.com/qgpmztmf/Stationary_Wavelet_Transform_PyTorch/blob/master/SWT.py
I found that it uses F.conv_transpose2d(lo, g0, padding=unpad, groups=C, dilation=dilation) to implement the 2D ISWT. It may be possible to use the groups parameter to make the computation more efficient: in your code, to_conv_t_list is built in a loop. I haven't checked this carefully, but it might help you. Anyway, thank you again!
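With groups=C and a weight of shape (C, 1, k), a transposed convolution applies filter c only to channel c, so a per-channel loop collapses into one call. The NumPy sketch below (not PyTorch, and not the code in SWT.py) checks that idea in 1D: a loop of full dilated convolutions per (batch, channel) pair matches a single pass vectorized over batch and channels. The function names are hypothetical, and the sketch ignores the padding=unpad cropping that SWT.py applies; it produces the full output length L + dilation·(k − 1) that a stride-1, unpadded transposed convolution gives.

```python
import numpy as np


def dilate_filter(w, dilation):
    """Spread the taps of w apart by inserting dilation - 1 zeros between them."""
    h = np.zeros(dilation * (len(w) - 1) + 1)
    h[::dilation] = w
    return h


def transposed_conv_loop(x, filters, dilation):
    """Reference: one full convolution per (batch, channel) pair.

    x has shape (batch, channels, length); filters has shape (channels, taps).
    """
    batch, channels, length = x.shape
    taps = filters.shape[1]
    out = np.zeros((batch, channels, length + dilation * (taps - 1)))
    for b in range(batch):
        for c in range(channels):
            h = dilate_filter(filters[c], dilation)
            out[b, c] = np.convolve(x[b, c], h, mode="full")
    return out


def transposed_conv_grouped(x, filters, dilation):
    """Same result, vectorized over batch and channels.

    Channel c only ever meets filter c, which is what groups=C expresses.
    Only the (short) loop over filter taps remains.
    """
    batch, channels, length = x.shape
    taps = filters.shape[1]
    out = np.zeros((batch, channels, length + dilation * (taps - 1)))
    for k in range(taps):
        start = k * dilation
        out[:, :, start:start + length] += x * filters[None, :, k, None]
    return out
```

Looping over taps rather than channels is the same trade the groups argument makes inside a single conv_transpose call: the channel count no longer drives the Python-level loop.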
Thanks to the author for his contribution, this project is great! My current work requires swt and iswt. I see that the author has provided experimental swt code, but no iswt code yet. I wrote an iswt modeled after the _swt() and waverec() functions. It produces a result with the right shape, but the result differs from the input to _swt(). May I ask where the problem is?
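Without seeing the code this is only a guess, but besides the dilation issue, one other thing worth checking when the shape is right and the values are wrong is normalization. If the waverec-style synthesis (transposed convolution with both synthesis filters) is run at stride 1, it adds two reconstructions of every sample and returns twice the input at that level; the SWT inverse averages them instead. The NumPy sketch below assumes the Haar wavelet and a circular boundary, and uses hypothetical names. It shows one level only: across several levels the detail bands are not rescaled, so the error is no longer a clean power of two.

```python
import numpy as np


def haar_swt_level(x, d):
    """One circular, undecimated Haar level with taps d samples apart."""
    nxt = np.roll(x, -d, axis=-1)  # x[n + d]
    return (x + nxt) / np.sqrt(2), (x - nxt) / np.sqrt(2)


def haar_synthesis_sum(a, det, d):
    """Full-rate synthesis that adds both reconstructions of x[n]
    (from position n and from position n - d), as a stride-1 transposed
    convolution with both synthesis filters does. Returns 2 * x."""
    here = (a + det) / np.sqrt(2)
    prev = np.roll((a - det) / np.sqrt(2), d, axis=-1)
    return here + prev


def haar_iswt_level(a, det, d):
    """Undecimated inverse: average the two reconstructions instead."""
    return 0.5 * haar_synthesis_sum(a, det, d)
```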