v0lta / PyTorch-Wavelet-Toolbox

Differentiable fast wavelet transforms in PyTorch with GPU support.
https://pytorch-wavelet-toolbox.readthedocs.io
European Union Public License 1.2
279 stars, 36 forks

My inverse stationary wavelet transform (ISWT) implementation gives wrong results. Is there a problem? #74

Closed by yutian-wang 6 months ago

yutian-wang commented 8 months ago

Thanks to the author for this contribution; the project is great! My current work requires the SWT and the ISWT. I see that the author has provided experimental SWT code, but no ISWT code yet. I wrote an ISWT implementation modeled on the _swt() and waverec() functions. It produces a result with the right shape, but the values differ from the input to _swt(). Where might the problem be?

def _iswt(
    coeffs: List[torch.Tensor],
    wavelet: Union[Wavelet, str],
    level: Optional[int] = None,
) -> torch.Tensor:

    torch_device = coeffs[0].device
    torch_dtype = coeffs[0].dtype

    for coeff in coeffs[1:]:
        if torch_device != coeff.device:
            raise ValueError("coefficients must be on the same device")
        elif torch_dtype != coeff.dtype:
            raise ValueError("coefficients must have the same dtype")

    _, _, rec_lo, rec_hi = _get_filter_tensors(
        wavelet, flip=False, device=torch_device, dtype=torch_dtype
    )

    filt_len = rec_lo.shape[-1]
    filt = torch.stack([rec_lo, rec_hi], 0)

    res_lo = coeffs[0]
    for cpos, res_hi in enumerate(coeffs[1:]):
        dilation = 2**cpos
        res_lo = torch.stack([res_lo, res_hi], 1)
        res_lo = torch.nn.functional.conv_transpose1d(
            res_lo, filt, stride=1, dilation=dilation
        )
        # remove the padding
        padl, padr = dilation * (filt_len // 2 - 1), dilation * (filt_len // 2)

        if padl > 0:
            res_lo = res_lo[..., padl:]
        if padr > 0:
            res_lo = res_lo[..., :-padr]

    return res_lo

v0lta commented 8 months ago

That's a good question. SWT support is not official because I never finished the iswt code. I will have to take a look and get back to you.

v0lta commented 8 months ago

You are facing the problem that conv_transpose fails to undo the dilation. I think this is where the requirements of CNNs and wavelet transforms differ. If the filters are chosen via gradient descent, exact inversion of the dilation is not required. For a wavelet toolbox like this one, however, it is. https://github.com/v0lta/PyTorch-Wavelet-Toolbox/pull/73/commits/2520a9175f2de8a9911b9d5a5e92b4eba25b7192 takes care of this.
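
The gap between transposed filtering and exact inversion can be sketched in plain Python. This is not the toolbox's implementation: it assumes Haar filters, periodic boundary handling, and hypothetical helper names (swt_step, adjoint_step, swt, iswt). Under these assumptions, stride-1 transposed filtering with the same dilation is the adjoint of the analysis step and returns twice the input, so every synthesis level needs a factor of 0.5, and the dilation has to match the one used at the corresponding analysis level.

```python
import math

# Haar filters: an assumption made for this sketch only.
S = 1.0 / math.sqrt(2.0)
LO = [S, S]
HI = [S, -S]


def swt_step(x, dilation):
    """One undecimated analysis step with periodic boundary handling."""
    n = len(x)
    lo = [sum(LO[k] * x[(i + dilation * k) % n] for k in range(2)) for i in range(n)]
    hi = [sum(HI[k] * x[(i + dilation * k) % n] for k in range(2)) for i in range(n)]
    return lo, hi


def adjoint_step(lo, hi, dilation):
    """Stride-1 transposed filtering: the adjoint of swt_step, not its inverse."""
    n = len(lo)
    out = [0.0] * n
    for i in range(n):
        for k in range(2):
            out[(i + dilation * k) % n] += LO[k] * lo[i] + HI[k] * hi[i]
    return out


def swt(x, level):
    """Return [approx, coarsest detail, ..., finest detail]."""
    approx = list(x)
    details = []
    for j in range(level):
        approx, hi = swt_step(approx, 2**j)
        details.append(hi)
    return [approx] + details[::-1]


def iswt(coeffs):
    """Invert swt() by halving the adjoint at every level."""
    approx = coeffs[0]
    level = len(coeffs) - 1
    for pos, hi in enumerate(coeffs[1:]):
        # The coarsest detail comes first, so it gets the largest dilation.
        dilation = 2 ** (level - 1 - pos)
        approx = [0.5 * v for v in adjoint_step(approx, hi, dilation)]
    return approx


signal = [float(v) for v in [3, 1, 4, 1, 5, 9, 2, 6]]
coeffs = swt(signal, level=2)
restored = iswt(coeffs)
max_err = max(abs(a - b) for a, b in zip(signal, restored))
```

Here max_err stays at floating-point rounding level. Without the 0.5 factor, a single adjoint_step returns the input scaled by 2.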

v0lta commented 8 months ago

It's not finished yet. Tests for level=None aren't passing, but if you pass an explicit level argument, it will work. Feel free to take a look via:

pip install git+ssh://git@github.com/v0lta/PyTorch-Wavelet-Toolbox.git@improved-docs

yutian-wang commented 8 months ago

Thank you for your answer, it helps me a lot!

yutian-wang commented 8 months ago

I tested your new code, and there is an issue. If the input batch size is greater than 1, the code raises an error. The error is caused by _conv_transpose_dedilate. For example, if I pass [(3,1024), (3,1024), (3,1024)] to iswt(), the first loop iteration returns (1,3072), and the second iteration fails because of a shape mismatch between (1,3072) and (3,1024). The ideal result of the first iteration is obviously (3,1024). I haven't fully understood your code yet, so I don't know how to fix it.
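
The expected per-iteration shape can be checked with a minimal sketch. It uses only torch.nn.functional.conv_transpose1d with hypothetical Haar synthesis filters and is not the toolbox's _conv_transpose_dedilate. Stacking the approximation and detail coefficients along the channel dimension leaves the batch dimension untouched.

```python
import torch

batch, length = 3, 1024
res_lo = torch.randn(batch, length)
res_hi = torch.randn(batch, length)

# Hypothetical Haar synthesis filters with shape
# (in_channels=2, out_channels=1, kernel=2); an assumption for this sketch.
s = 2.0 ** -0.5
filt = torch.tensor([[[s, s]], [[s, -s]]])

dilation = 1
# (batch, 2, length): the batch dimension stays in front.
stacked = torch.stack([res_lo, res_hi], dim=1)
out = torch.nn.functional.conv_transpose1d(
    stacked, filt, stride=1, dilation=dilation
)
# (batch, length + dilation * (kernel - 1)) after dropping the channel axis.
out = out.squeeze(1)
```

The result keeps three rows, one per batch entry. Only the last dimension grows by dilation * (kernel - 1) samples, which still have to be trimmed or folded back depending on the boundary handling.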

v0lta commented 8 months ago

Ah, yes, I see. The tests currently don't cover batched inputs either. Good catch, I will look into this!

v0lta commented 8 months ago

37d0d31 no longer has the batch problem. However, level-argument support still needs more work. Please let me know if this solves your problem.

yutian-wang commented 8 months ago

Thank you very much for your contribution! I tested your code, and it gives correct results when the batch size is greater than 1. For the past few days I have been studying this code: https://github.com/qgpmztmf/Stationary_Wavelet_Transform_PyTorch/blob/master/SWT.py It uses F.conv_transpose2d(lo, g0, padding=unpad, groups=C, dilation=dilation) to implement the 2D ISWT, so the groups parameter might make the computation more efficient. In your code, to_conv_t_list is built iteratively. I haven't checked this carefully, but it might help you. Anyway, thank you again!
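
A rough 1D sketch of the groups idea follows. It assumes hypothetical Haar filters and made-up tensor sizes, and it reproduces neither the linked SWT.py nor the toolbox's to_conv_t_list. A single grouped conv_transpose1d call gives the same result as one call per channel; whether it is actually faster is not measured here.

```python
import torch
import torch.nn.functional as F

batch, channels, length = 2, 3, 16
lo = torch.randn(batch, channels, length)
hi = torch.randn(batch, channels, length)

# Hypothetical Haar synthesis filter pair, shape (2, 1, kernel).
s = 2.0 ** -0.5
filt = torch.tensor([[[s, s]], [[s, -s]]])
dilation = 2

# Interleave as (lo_0, hi_0, lo_1, hi_1, ...): (batch, 2 * channels, length).
paired = torch.stack([lo, hi], dim=2).reshape(batch, 2 * channels, length)
# One copy of the filter pair per group: (2 * channels, 1, kernel).
weight = filt.repeat(channels, 1, 1)
grouped = F.conv_transpose1d(paired, weight, groups=channels, dilation=dilation)

# Reference: one conv_transpose1d call per channel, concatenated afterwards.
looped = torch.cat(
    [
        F.conv_transpose1d(
            torch.stack([lo[:, c], hi[:, c]], dim=1), filt, dilation=dilation
        )
        for c in range(channels)
    ],
    dim=1,
)
```

Both grouped and looped have shape (batch, channels, length + dilation * (kernel - 1)) and agree up to floating-point tolerance.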