v0lta / PyTorch-Wavelet-Toolbox

Differentiable fast wavelet transforms in PyTorch with GPU support.
European Union Public License 1.2
279 stars, 36 forks

My inverse stationary wavelet transform (ISWT) implementation gives wrong results. Is there a problem? #74

Closed: yutian-wang closed this issue 6 months ago

yutian-wang commented 8 months ago

Thanks to the author for this contribution; the project is great! My current work requires SWT and ISWT. I see that the author has provided experimental SWT code, but no ISWT code yet. I wrote an ISWT modeled after the _swt() and waverec() functions. It produces a result with the right shape, but the result differs from the input I passed to _swt(). Where is the problem?

from typing import List, Optional, Union

import torch

# ptwt internals, assumed to live here at the time of this issue
# (the import paths may differ between ptwt versions).
from ptwt._util import Wavelet
from ptwt.conv_transform import _get_filter_tensors


def _iswt(
    coeffs: List[torch.Tensor],
    wavelet: Union[Wavelet, str],
    level: Optional[int] = None,  # currently unused
) -> torch.Tensor:

    torch_device = coeffs[0].device
    torch_dtype = coeffs[0].dtype

    for coeff in coeffs[1:]:
        if torch_device != coeff.device:
            raise ValueError("coefficients must be on the same device")
        elif torch_dtype != coeff.dtype:
            raise ValueError("coefficients must have the same dtype")

    # reconstruction filters, each of shape (1, filt_len)
    _, _, rec_lo, rec_hi = _get_filter_tensors(
        wavelet, flip=False, device=torch_device, dtype=torch_dtype
    )

    filt_len = rec_lo.shape[-1]
    filt = torch.stack([rec_lo, rec_hi], 0)  # (2, 1, filt_len)

    res_lo = coeffs[0]
    for cpos, res_hi in enumerate(coeffs[1:]):
        dilation = 2**cpos
        res_lo = torch.stack([res_lo, res_hi], 1)  # (batch, 2, length)
        res_lo = torch.nn.functional.conv_transpose1d(
            res_lo, filt, stride=1, dilation=dilation
        )
        res_lo = res_lo.squeeze(1)  # (batch, 1, length') -> (batch, length')
        # remove the padding
        padl, padr = dilation * (filt_len // 2 - 1), dilation * (filt_len // 2)

        if padl > 0:
            res_lo = res_lo[..., padl:]
        if padr > 0:
            res_lo = res_lo[..., :-padr]

    return res_lo

v0lta commented 8 months ago

That's a good question. SWT support is not official because I never finished the iswt code. I will have to take a look and get back to you.

v0lta commented 8 months ago

You are facing the problem that conv_transpose1d does not undo the dilation. I think this is where the requirements of CNNs and of wavelet transforms differ. When the filters are learned via gradient descent, the dilated convolution does not need to be inverted exactly. A wavelet toolbox like this one, however, requires exact inversion. Commit https://github.com/v0lta/PyTorch-Wavelet-Toolbox/pull/73/commits/2520a9175f2de8a9911b9d5a5e92b4eba25b7192 takes care of this.
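As a point of reference for the exact-inversion requirement, here is a minimal sketch (not the toolbox's code) of a 1D undecimated "à trous" wavelet transform pair with periodic (circular) boundary handling. To keep the index bookkeeping explicit it uses torch.roll instead of conv1d/conv_transpose1d, and its coefficient phase need not match pywt or ptwt. Three details make it invert exactly: the dilations are undone in reverse order (the coarsest coefficients, listed first, used the largest dilation 2**(level - 1)); each synthesis filter is applied as the adjoint (circular correlation) of its analysis filter; and the two branches are averaged with a factor 1/2, because an orthonormal filter pair satisfies |H0(ω)|² + |H1(ω)|² = 2. The dilation order and the factor 1/2 both differ from the _iswt attempt in the question. The names _dilated_filter, qmf, swt_sketch and iswt_sketch are hypothetical, and the Haar and db2 filters are written out by hand.

```python
import math

import torch


def _dilated_filter(x, filt, dilation, correlate=False):
    """Circularly filter the last axis of x with `filt` upsampled by `dilation`.

    convolution: out[n] = sum_k filt[k] * x[(n - dilation * k) mod N]
    correlation: out[n] = sum_k filt[k] * x[(n + dilation * k) mod N]
    """
    sign = -1 if correlate else 1
    out = torch.zeros_like(x)
    for k, f in enumerate(filt.tolist()):
        # torch.roll(x, s)[n] == x[(n - s) mod N]
        out = out + f * torch.roll(x, shifts=sign * dilation * k, dims=-1)
    return out


def qmf(h0):
    """Hypothetical helper: high-pass partner h1[k] = (-1)**k * h0[L - 1 - k]."""
    signs = torch.tensor([(-1.0) ** k for k in range(h0.numel())], dtype=h0.dtype)
    return signs * h0.flip(0)


def swt_sketch(x, h0, h1, level):
    """Undecimated analysis; returns [a_level, d_level, ..., d_1] (coarsest first)."""
    approx, details = x, []
    for j in range(level):
        dilation = 2**j
        details.append(_dilated_filter(approx, h1, dilation))
        approx = _dilated_filter(approx, h0, dilation)
    return [approx] + details[::-1]


def iswt_sketch(coeffs, h0, h1):
    """Inverse of swt_sketch for an orthonormal filter pair (h0, h1)."""
    approx, level = coeffs[0], len(coeffs) - 1
    for pos, detail in enumerate(coeffs[1:]):
        dilation = 2 ** (level - 1 - pos)  # coarsest detail used the largest dilation
        approx = 0.5 * (
            _dilated_filter(approx, h0, dilation, correlate=True)
            + _dilated_filter(detail, h1, dilation, correlate=True)
        )
    return approx


# Orthonormal low-pass filters: Haar and Daubechies-2 (db2).
haar_lo = torch.tensor([1.0, 1.0], dtype=torch.float64) / math.sqrt(2)
r3 = math.sqrt(3)
db2_lo = torch.tensor([1 + r3, 3 + r3, 3 - r3, 1 - r3], dtype=torch.float64) / (4 * math.sqrt(2))

torch.manual_seed(0)
x = torch.randn(3, 64, dtype=torch.float64)  # a batch of three signals
for h0 in (haar_lo, db2_lo):
    h1 = qmf(h0)
    coeffs = swt_sketch(x, h0, h1, level=3)
    print(torch.allclose(iswt_sketch(coeffs, h0, h1), x))  # True
```

Because every step works along the last axis only, the batch dimension passes through untouched, and all coefficient tensors keep the input length, as expected for a stationary transform.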

v0lta commented 8 months ago

It's not finished yet: the tests for level=None aren't passing. If you pass an explicit level argument, however, it works. Feel free to try it via:

pip install git+ssh://git@github.com/v0lta/PyTorch-Wavelet-Toolbox.git@improved-docs
yutian-wang commented 8 months ago

Thank you for your answer, it helps me a lot!

yutian-wang commented 8 months ago

I tested your new code and found an issue: if the batch size is greater than 1, the code raises an error, caused by _conv_transpose_dedilate. For example, if I pass coefficients of shapes [(3, 1024), (3, 1024), (3, 1024)] to iswt(), the first loop iteration returns shape (1, 3072), and the second iteration fails with a shape mismatch between (1, 3072) and (3, 1024). The first iteration should obviously return (3, 1024). But I haven't fully understood your code yet, so I don't know how to fix it.
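The "dedilation" discussed here rests on a standard identity: a convolution with dilation d only combines samples whose indices agree modulo d, so it splits into d ordinary convolutions, one per interleaved (polyphase) subsequence x[i::d]. The same holds for the transposed convolution used during synthesis. When those phases are folded into the batch dimension, the fold and unfold have to keep the batch index separate from the phase index so that a batch of 3 stays a batch of 3. The helper below is a hypothetical sketch, not the toolbox's _conv_transpose_dedilate, and it does not claim to reproduce the actual cause of the error above. It keeps the batch index outermost, assumes the signal length is divisible by the dilation, and is checked against F.conv1d with dilation.

```python
import torch
import torch.nn.functional as F


def dilated_conv_via_polyphase(x, w, dilation):
    """Compute F.conv1d(x, w, dilation=dilation) with ordinary (undilated) convolutions.

    x: (batch, channels, length), length assumed divisible by `dilation`
    w: (out_channels, channels, k)
    """
    b, c, n = x.shape
    d = dilation
    # Polyphase split: phase i holds x[..., i::d]. The batch index stays outermost,
    # so row b * d + i of the folded tensor is phase i of sample b.
    phases = x.reshape(b, c, n // d, d).permute(0, 3, 1, 2).reshape(b * d, c, n // d)
    y = F.conv1d(phases, w)  # no dilation needed inside a single phase
    m = y.shape[-1]
    # Unfold: separate batch and phase again, then interleave the phases in time.
    return y.reshape(b, d, -1, m).permute(0, 2, 3, 1).reshape(b, -1, m * d)


torch.manual_seed(0)
x = torch.randn(3, 1, 1024, dtype=torch.float64)  # batch of 3, length 1024
w = torch.randn(2, 1, 4, dtype=torch.float64)
for d in (1, 2, 4):
    ref = F.conv1d(x, w, dilation=d)
    out = dilated_conv_via_polyphase(x, w, d)
    print(d, tuple(out.shape), torch.allclose(out, ref))
```

The leading dimension of every intermediate tensor is either batch or batch * dilation, never a mix, which is what keeps batched inputs intact.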

v0lta commented 8 months ago

Ah yes, I see. The tests currently don't cover batched inputs either. Good catch, I will look into this!

v0lta commented 8 months ago

37d0d31 no longer has the batch problem. However, level-argument support still needs more work. Please let me know if this solves your problem.

yutian-wang commented 8 months ago

Thank you very much for your contribution! I confirmed that your code gives correct results when batch > 1. For the past few days I've been studying this code: https://github.com/qgpmztmf/Stationary_Wavelet_Transform_PyTorch/blob/master/SWT.py. I found that it implements the 2D ISWT with F.conv_transpose2d(lo, g0, padding=unpad, groups=C, dilation=dilation). It may be possible to use the groups parameter to make the computation more efficient; in your code, to_conv_t_list is built with a Python loop. I haven't checked this carefully, but maybe it helps. Anyway, thank you again!
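To illustrate what the groups argument buys: with groups=C, one transposed convolution processes C channels independently, combining input channels 2c and 2c + 1 (the low- and high-pass coefficients of channel c) only with their own filter pair and writing the result to output channel c. The sketch below compares such a grouped call with a per-channel loop for a single synthesis step. The function names are hypothetical, the filters are random placeholders, and padding removal is omitted. Whether this also replaces the loop that builds to_conv_t_list depends on what that loop iterates over, which this sketch does not address.

```python
import torch
import torch.nn.functional as F


def synthesis_grouped(lo, hi, filt, dilation):
    """One transposed-conv synthesis step for all C channels in a single call.

    lo, hi: (batch, C, n); filt: (2, 1, k), the stacked [rec_lo, rec_hi] filters
    """
    b, c, n = lo.shape
    x = torch.stack([lo, hi], dim=2).reshape(b, 2 * c, n)  # [lo_0, hi_0, lo_1, hi_1, ...]
    weight = filt.repeat(c, 1, 1)  # (2C, 1, k): one [lo, hi] filter pair per group
    return F.conv_transpose1d(x, weight, groups=c, dilation=dilation)


def synthesis_loop(lo, hi, filt, dilation):
    """Reference: the same step, one channel at a time."""
    outs = [
        F.conv_transpose1d(torch.stack([lo[:, i], hi[:, i]], 1), filt, dilation=dilation)
        for i in range(lo.shape[1])
    ]
    return torch.cat(outs, dim=1)


torch.manual_seed(0)
lo = torch.randn(3, 5, 32, dtype=torch.float64)
hi = torch.randn(3, 5, 32, dtype=torch.float64)
filt = torch.randn(2, 1, 4, dtype=torch.float64)
grouped = synthesis_grouped(lo, hi, filt, dilation=2)
looped = synthesis_loop(lo, hi, filt, dilation=2)
print(tuple(grouped.shape), torch.allclose(grouped, looped))
```

The output length is n + dilation * (k - 1) for both variants; the grouped version simply moves the channel loop into a single kernel launch.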