v0lta / PyTorch-Wavelet-Toolbox

Differentiable fast wavelet transforms in PyTorch with GPU support.
https://pytorch-wavelet-toolbox.readthedocs.io
European Union Public License 1.2
279 stars, 36 forks

Insufficient padding removal in wavelet packet reconstruction. #52

Status: Closed (felixblanke closed this issue 1 year ago)

felixblanke commented 1 year ago

I noticed that our wavelet packet reconstruction does not fully remove the padding that is added during decomposition.

The following snippet computes the reconstruction sizes for the 2D transforms, covering each combination of transform, boundary mode, and separability it lists:

```python
import ptwt
import torch


def print_reconstruction_sizes(size: int, wavelet="db4", level=3):
    data = torch.eye(size, device="cuda", dtype=torch.float32)
    reconstructions = []
    # Entries 1-2: separable and fully separable FWT with reflect padding.
    reconstructions.append(ptwt.waverec2(ptwt.wavedec2(data=data, wavelet=wavelet, level=level, mode="reflect"), wavelet=wavelet))
    reconstructions.append(ptwt.fswaverec2(ptwt.fswavedec2(data, wavelet=wavelet, mode="reflect", level=level), wavelet=wavelet))
    # Entry 3: wavelet packets with reflect padding.
    reconstructions.append(ptwt.WaveletPacket2D(data=data, wavelet=wavelet, maxlevel=level, mode="reflect").reconstruct()[""])
    # Entries 4-5: matrix-based FWT, separable and non-separable.
    reconstructions.append(ptwt.MatrixWaverec2(wavelet=wavelet)(ptwt.MatrixWavedec2(wavelet=wavelet, level=level)(data)))
    reconstructions.append(ptwt.MatrixWaverec2(wavelet=wavelet, separable=False)(ptwt.MatrixWavedec2(wavelet=wavelet, level=level, separable=False)(data)))
    # Entries 6-7: wavelet packets with boundary filters, non-separable and separable.
    reconstructions.append(ptwt.WaveletPacket2D(data=data, wavelet=wavelet, maxlevel=level, mode="boundary", separable=False).reconstruct()[""])
    reconstructions.append(ptwt.WaveletPacket2D(data=data, wavelet=wavelet, maxlevel=level, mode="boundary", separable=True).reconstruct()[""])
    print(f"reconstruction sizes for size {size}: {[rec.shape[-1] for rec in reconstructions]}")
```

Running it on the current v0.1.5 dev branch for some sample input sizes yields:

```python
for size in range(42, 80, 2):
    print_reconstruction_sizes(size)
```

Output:

```
reconstruction sizes for size 42: [42, 42, 46, 42, 42, 48, 48]
reconstruction sizes for size 44: [44, 44, 46, 44, 44, 48, 48]
reconstruction sizes for size 46: [46, 46, 46, 46, 46, 48, 48]
reconstruction sizes for size 48: [48, 48, 54, 48, 48, 48, 48]
reconstruction sizes for size 50: [50, 50, 54, 50, 50, 56, 56]
reconstruction sizes for size 52: [52, 52, 54, 52, 52, 56, 56]
reconstruction sizes for size 54: [54, 54, 54, 54, 54, 56, 56]
reconstruction sizes for size 56: [56, 56, 62, 56, 56, 56, 56]
reconstruction sizes for size 58: [58, 58, 62, 58, 58, 64, 64]
reconstruction sizes for size 60: [60, 60, 62, 60, 60, 64, 64]
reconstruction sizes for size 62: [62, 62, 62, 62, 62, 64, 64]
reconstruction sizes for size 64: [64, 64, 70, 64, 64, 64, 64]
reconstruction sizes for size 66: [66, 66, 70, 66, 66, 72, 72]
reconstruction sizes for size 68: [68, 68, 70, 68, 68, 72, 72]
reconstruction sizes for size 70: [70, 70, 70, 70, 70, 72, 72]
reconstruction sizes for size 72: [72, 72, 78, 72, 72, 72, 72]
reconstruction sizes for size 74: [74, 74, 78, 74, 74, 80, 80]
reconstruction sizes for size 76: [76, 76, 78, 76, 76, 80, 80]
reconstruction sizes for size 78: [78, 78, 78, 78, 78, 80, 80]
```

felixblanke commented 1 year ago

The underlying problem is that the packet reconstruction rebuilds each parent node from its child nodes alone. The current implementation ignores whatever is stored in the parent node, or in any node above it, and depends solely on the child node entries. So with the current implementation, given only the wavelet packet coefficients of a specific level, one could initialize the leaf nodes of an empty wavelet packet tree, set maxlevel accordingly, and then reconstruct the original signal directly:

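```python
# `data` is a 2D tensor, e.g. the one built in the snippet above.
a = ptwt.WaveletPacket2D(data=data, wavelet="db4", maxlevel=3)
b = ptwt.WaveletPacket2D(data=None, wavelet="db4")  # new empty wavelet packet tree
# Copy only the leaf coefficients of the deepest level into the empty tree.
for node in a.get_natural_order(a.maxlevel):
    b[node] = a[node]
b.maxlevel = a.maxlevel
b.reconstruct()[""]  # reconstructed signal
```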

However, this construction cannot remove the padding that is applied in the wavelet decomposition as it has no shape to compare the reconstructed coefficients against.
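The size surplus can be reproduced with plain length arithmetic. The sketch below is a pure-Python model, not ptwt code: it assumes the pywt length conventions (a level of decomposition maps length `n` to `(n + F - 1) // 2`, and an untrimmed level of reconstruction maps `n` to `2 * n - F + 2`, with filter length `F = 8` for db4). The helper names are made up for illustration.

```python
def dec_len(n: int, filt_len: int) -> int:
    """Coefficient length after one decomposition level (pywt convention)."""
    return (n + filt_len - 1) // 2


def rec_len(n: int, filt_len: int) -> int:
    """Length after one reconstruction level when nothing is trimmed."""
    return 2 * n - filt_len + 2


def untrimmed_size(size: int, filt_len: int = 8, level: int = 3) -> int:
    """Size of a reconstruction that never compares against stored shapes."""
    n = size
    for _ in range(level):
        n = dec_len(n, filt_len)
    for _ in range(level):
        n = rec_len(n, filt_len)
    return n


for size in (42, 46, 48):
    print(size, untrimmed_size(size))
```

For sizes 42, 46 and 48 this gives 46, 46 and 54, which are the values the reflect-mode `WaveletPacket2D` entry (third column) reports in the output above. Under these assumptions, the surplus appears whenever an intermediate level has odd length, because the extra padded sample is never cut away again.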

I implemented a fix in 16a74f9dda39fb88bfd0714a117a35c8006a1c13. It compares the reconstructed shape against the coefficient shape that was previously stored in the node. This allows padding removal, but an initializing forward pass is now always needed. So the example above becomes:

```python
a = ptwt.WaveletPacket2D(data=data, wavelet="db4", maxlevel=3)
# now apply a forward pass with the correct input shape at construction
b = ptwt.WaveletPacket2D(data=torch.zeros_like(data), wavelet="db4", maxlevel=a.maxlevel)
for node in a.get_natural_order(a.maxlevel):
    b[node] = a[node]
b.reconstruct()[""]  # reconstructed signal
```

felixblanke commented 1 year ago

For consistency with the FWT implementation, padding added on level 0 is not removed. The FWTs have no coefficients to compare the level-zero reconstruction against, so padding that was added to make the signal length even stays in place. For the wavelet packets, however, we could compare against the root node to remove such padding, i.e. return a reconstruction with an odd length if the input signal has an odd length. To do this, we could simply remove the `if level > 0` checks from 16a74f9 (e.g. this one). Any thoughts on this, @v0lta?
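The shape comparison, including the level-0 question, can be modelled on lengths alone. This is a hedged pure-Python sketch rather than the code in 16a74f9: it assumes the pywt length conventions and filter length 8 for db4, and `trim_level_zero` is a hypothetical flag standing in for removing the `level > 0` checks.

```python
def dec_len(n: int, filt_len: int) -> int:
    """Coefficient length after one decomposition level (pywt convention)."""
    return (n + filt_len - 1) // 2


def packet_reconstruction_size(
    size: int, filt_len: int = 8, level: int = 3, trim_level_zero: bool = False
) -> int:
    """Model the reconstruction size when stored shapes are used for trimming."""
    # Forward pass: remember the size at every level (index 0 is the input).
    stored = [size]
    for _ in range(level):
        stored.append(dec_len(stored[-1], filt_len))

    n = stored[-1]
    for lvl in range(level - 1, -1, -1):
        n = 2 * n - filt_len + 2  # untrimmed reconstruction length
        if lvl > 0 or trim_level_zero:
            # Anything beyond the stored size is padding, so cut it away.
            n = min(n, stored[lvl])
    return n


print(packet_reconstruction_size(42))                        # even input
print(packet_reconstruction_size(43))                        # odd input, level 0 untouched
print(packet_reconstruction_size(43, trim_level_zero=True))  # odd input, root compared
```

In this model an even input such as 42 comes back at its original size once the intermediate levels are trimmed. An odd input of 43 comes back as 44 unless the root shape is compared as well, in which case it returns 43.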

v0lta commented 1 year ago

It's super lovely you fixed this. I thought not always removing all of the padding was ok. We have the same problem with wavedec and waverec for odd inputs. Since waverec can't know what wavedec did, it leaves the odd padding in place, and users who know the original input shape have to remove it themselves. This behaviour is consistent with pywt.

However, most users will want the padding removed, so I am guessing hardcoding it is ok. I can't think of any unintended side effects this could have.

v0lta commented 1 year ago

So thanks for making the toolbox better!