fbcotter / pytorch_wavelets

PyTorch implementation of the 2D Discrete Wavelet Transform (DWT), the Dual-Tree Complex Wavelet Transform (DTCWT), and a DTCWT-based ScatterNet

DWTInverse is not the inverse of DWTForward for small image sizes. #42

Open: FlorentinGuth opened this issue 2 years ago

FlorentinGuth commented 2 years ago

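For images smaller than the wavelet support (db4 filters have 8 taps), DWTInverse is not the inverse of DWTForward, probably because of how the borders are handled. In the reproduction below, the round trip fails for 2×2 and 4×4 images and is exact up to float32 precision for 8×8, 16×16 and 32×32 images.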
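The suspicion that border handling is to blame, rather than the transform itself, can be backed by a short derivation. It assumes that `mode="periodization"` is meant to be the standard periodized (circular) orthogonal DWT; h is the db4 low-pass filter (K = 8 taps), g its quadrature-mirror high-pass filter, N the (even) signal length along one axis, and φ_k, γ_k the k-th low- and high-pass rows of the analysis matrix:

```latex
\begin{align*}
a_k &= \sum_{n=0}^{K-1} h_n \, x_{(2k+n) \bmod N}, \qquad
d_k  = \sum_{n=0}^{K-1} g_n \, x_{(2k+n) \bmod N}, \qquad k = 0, \dots, \tfrac{N}{2} - 1, \\
\langle \phi_k, \phi_m \rangle
  &= \sum_{j \in \mathbb{Z}} \sum_{n} h_n \, h_{n + 2(k-m) + jN}
   = \sum_{j \in \mathbb{Z}} \delta_{2(k-m) + jN,\, 0}
   = \delta_{km}, \\
\langle \phi_k, \gamma_m \rangle
  &= \sum_{j \in \mathbb{Z}} \sum_{n} h_n \, g_{n + 2(k-m) + jN} = 0 .
\end{align*}
```

The middle step uses the orthonormality relations Σ h_n h_{n+2ℓ} = δ_{ℓ0} and Σ h_n g_{n+2ℓ} = 0, which apply because N is even, so every shift 2(k-m) + jN is even. The last step holds because |2(k-m)| < N, leaving only j = 0 and k = m. The same argument gives ⟨γ_k, γ_m⟩ = δ_km, so the N rows are an orthonormal basis and the transpose inverts the transform exactly for every even N, including N = 2, however short the signal is relative to the filter.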

How to reproduce

With pytorch_wavelets 1.3.0 and Python 3.8.11:

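import torch
from pytorch_wavelets import DWTForward, DWTInverse

def wavelet_check(L):
    # Relative error ||true - other|| / ||true|| over all entries
    n = lambda x: torch.norm(torch.flatten(x))
    rel_err = lambda true, other: (n(true - other) / n(true)).item()

    # Single-level transforms (J=1, the default)
    forward = DWTForward(wave="db4", mode="periodization")
    inverse = DWTInverse(wave="db4", mode="periodization")

    N = 10
    # Channel 0 is the lowpass band, channels 1-3 the three highpass bands,
    # each L x L; the highpass list entry has shape (N, C, 3, L, L).
    x = torch.rand((N, 4, L, L))
    y = inverse((x[:, :1], [x[:, None, 1:4]]))  # y is 2L x 2L
    # Decompose y again and restack the bands in the same layout as x
    l, h = forward(y)
    x_rec = torch.cat((l, h[0][:, 0, ...]), dim=1)
    print(f"L={L}, relative error {rel_err(x, x_rec):.2}")

for j in range(5):
    wavelet_check(2 ** j)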

This produces the following output (for each L, the synthesized image y is 2L x 2L):

L=1, relative error 1.0
L=2, relative error 0.88
L=4, relative error 1.2e-07
L=8, relative error 1.2e-07
L=16, relative error 1.2e-07
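One way to read this pattern, offered as a hypothesis rather than a diagnosis of pytorch_wavelets' internals: suppose the implementation extends the signal periodically only once, by the K-2 = 6 extra samples an 8-tap filter reads at stride 2. That is enough when an axis has at least 6 samples and falls short at 2 and 4, the axis lengths that fail above. The pure-Python sketch below compares a 1D periodized db4 analysis that reduces every index modulo N against such a wrap-once variant; both are inverted with the same exact synthesis. The names `analysis`, `synthesis`, `rel_err` and the `wrap_once` flag are illustrative, not library API, and the filter values are the db4 coefficients tabulated by PyWavelets, hard-coded to keep the sketch dependency-free.

```python
import random

# db4 analysis low-pass filter (8 taps), values as tabulated by PyWavelets
# (pywt.Wavelet("db4").dec_lo); hard-coded so the sketch has no dependencies.
H = [-0.010597401785069032, 0.0328830116668852, 0.030841381835560764,
     -0.18703481171909309, -0.027983769416859854, 0.6308807679298589,
     0.7148465705529157, 0.2303778133088965]
K = len(H)
# Quadrature-mirror high-pass filter: g[n] = (-1)^(n+1) * h[K-1-n]
G = [(-1) ** (n + 1) * H[K - 1 - n] for n in range(K)]


def analysis(x, wrap_once=False):
    """One level of a periodized 1D DWT; returns (approx, detail), N/2 each."""
    N = len(x)
    ext = x + x[:K - 2]  # used only by the hypothetical wrap-once variant

    def get(i):
        if wrap_once:
            # Hypothetical flawed border handling: the signal is extended by
            # its first K-2 samples a single time, so when N < K-2 some
            # indices fall past the extension and read as zero.
            return ext[i] if i < len(ext) else 0.0
        # Correct periodization: reduce every index modulo N.
        return x[i % N]

    a = [sum(H[n] * get(2 * k + n) for n in range(K)) for k in range(N // 2)]
    d = [sum(G[n] * get(2 * k + n) for n in range(K)) for k in range(N // 2)]
    return a, d


def synthesis(a, d):
    """Transpose of the modular analysis, which is its exact inverse."""
    N = 2 * len(a)
    x = [0.0] * N
    for k in range(len(a)):
        for n in range(K):
            x[(2 * k + n) % N] += H[n] * a[k] + G[n] * d[k]
    return x


def rel_err(true, other):
    """Relative error ||true - other|| / ||true||."""
    num = sum((p - q) ** 2 for p, q in zip(true, other)) ** 0.5
    return num / sum(p * p for p in true) ** 0.5


random.seed(0)
results = {}
for N in (2, 4, 8, 16):
    x = [random.random() for _ in range(N)]
    modular = rel_err(x, synthesis(*analysis(x)))
    once = rel_err(x, synthesis(*analysis(x, wrap_once=True)))
    results[N] = (modular, once)
    print(f"N={N}: modular {modular:.1e}, wrap-once {once:.1e}")
```

The modular version should reconstruct to rounding error at every length, while the wrap-once variant fails only at N = 2 and N = 4. A separable 2D transform applies the same 1D step along each axis, so for the 2L x 2L images in the test above this model predicts failures at L = 1 and L = 2 only. The sketch flaws only the analysis step, whereas the issue's test runs inverse then forward; a wrap-once shortcut in either direction would break the round trip, so this is consistent with the border-effect explanation rather than proof of it.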