v0lta / PyTorch-Wavelet-Toolbox

Differentiable fast wavelet transforms in PyTorch with GPU support.
https://pytorch-wavelet-toolbox.readthedocs.io
European Union Public License 1.2

WaveletPacket2D yields incoherent results for certain data shape/wavelet level combinations #21

Closed. RaoulHeese closed this issue 2 years ago.

RaoulHeese commented 2 years ago

Calling ptwt.WaveletPacket2D (and processing its outputs) can lead to uninformative exceptions or unexpected results, depending on the image size and the wavelet level. The following code snippet demonstrates this behavior with six simple tests on a generated test image.

import numpy as np
from PIL import Image
import torch
import pywt
import ptwt
from itertools import product

def generate_PIL_img(size=256, channels=3):
    # generate a demo image: a regular white grid on a black background
    img = np.zeros((size, size, channels), dtype=np.uint8)
    img[::size // 8] = 255
    img[:, ::size // 8] = 255
    return Image.fromarray(img)

def wl_transform(img, img_size, max_lev, wavelet_str="db5", mode="reflect"):
    # wavelet transform with pipeline: PIL -> numpy -> torch -> numpy
    img = img.resize((img_size,img_size))
    image_batch = np.array(img)[None,:]
    image_batch_tensor = torch.from_numpy(image_batch.astype(np.float32))

    wavelet = pywt.Wavelet(wavelet_str)
    wp_keys = ["".join(node) for node in product(["a", "h", "v", "d"], repeat=max_lev)]

    channels = []
    for channel in range(image_batch_tensor.shape[-1]):
        with torch.no_grad():
            pt_data = image_batch_tensor[:, :, :, channel]
            ptwt_wp_tree = ptwt.WaveletPacket2D(data=pt_data, wavelet=wavelet, mode=mode)
            packet_list = []
            for node in wp_keys:
                packet = torch.squeeze(ptwt_wp_tree[node], dim=1)  # node is already a joined key string
                packet_list.append(packet)
            channel_packets = torch.stack(packet_list, dim=1)
        channels.append(channel_packets)
    packets = torch.stack(channels, -1)

    return packets.numpy()

# TEST 1
# Ok 
wav = wl_transform(generate_PIL_img(), 128, 2)
assert len(wav.shape) == 5

# TEST 2
# Exception 1: AssertionError
wav = wl_transform(generate_PIL_img(), 128, 3)
assert len(wav.shape) == 5 # wav.shape = (23, 64, 23, 3)

# TEST 3
# Ok 
wav = wl_transform(generate_PIL_img(), 256, 3)
assert len(wav.shape) == 5

# TEST 4
# Exception 2: KeyError
wav = wl_transform(generate_PIL_img(), 128, 4) # 'aaaa' not found in ptwt_wp_tree
assert len(wav.shape) == 5

# TEST 5
# Exception 3: AssertionError 
wav = wl_transform(generate_PIL_img(), 256, 4)
assert len(wav.shape) == 5 # wav.shape = (24, 256, 24, 3)

# TEST 6
# Ok 
wav = wl_transform(generate_PIL_img(), 512, 4)
assert len(wav.shape) == 5

The expected behavior would be either an informative exception on incorrect usage or a valid output otherwise.
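For context, the KeyError in TEST 4 appears to coincide with the maximum decomposition level that pywt itself reports for the given signal length and wavelet: for db5 (decomposition filter length 10) and a side length of 128, pywt.dwt_max_level returns 3, so level-4 keys such as 'aaaa' cannot exist. A minimal, hypothetical caller-side guard (not part of ptwt) could look like this:

import pywt

def check_max_level(img_size, wavelet_str, max_lev):
    # largest useful decomposition level for a signal of length img_size
    wavelet = pywt.Wavelet(wavelet_str)
    max_allowed = pywt.dwt_max_level(img_size, wavelet.dec_len)
    if max_lev > max_allowed:
        raise ValueError(f"level {max_lev} exceeds maximum {max_allowed} "
                         f"for size {img_size} and wavelet {wavelet_str}")

check_max_level(128, "db5", 3)  # passes: maximum level is 3
check_max_level(128, "db5", 4)  # raises: level-4 nodes such as 'aaaa' cannot exist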

v0lta commented 2 years ago

Good catch, I think I fixed the problem in https://github.com/v0lta/PyTorch-Wavelet-Toolbox/commit/fb31c383c1472a0cb328f5c72d7327afec9aca9f. We had a bare squeeze in the code, which kills the batch dimension when it is 1. I will test this some more.
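For illustration only (this is not the toolbox code), the pitfall described above is easy to reproduce: a bare squeeze removes every size-1 axis, including a batch dimension of 1, while squeezing an explicit axis leaves the batch intact.

import torch

coeffs = torch.zeros(1, 1, 23, 23)  # (batch, extra, height, width), batch size 1
print(coeffs.squeeze().shape)       # torch.Size([23, 23])    -> batch dimension is gone
print(coeffs.squeeze(1).shape)      # torch.Size([1, 23, 23]) -> batch dimension survives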

v0lta commented 2 years ago

Okay, I think the patch is ready: https://github.com/v0lta/PyTorch-Wavelet-Toolbox/pull/23 deals with this.

v0lta commented 2 years ago

Regarding the KeyError, we want a warning; we already have some for the matrix FWTs, see https://github.com/v0lta/PyTorch-Wavelet-Toolbox/pull/16.
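A hypothetical sketch of such a check, following the pattern of the matrix FWT warnings (the helper name below is illustrative, not the actual ptwt internals):

import warnings
import pywt

def _warn_if_level_too_high(data_len, wavelet, level):
    # illustrative only: warn instead of failing with a bare KeyError later on
    max_level = pywt.dwt_max_level(data_len, wavelet.dec_len)
    if level > max_level:
        warnings.warn(f"Requested level {level} exceeds the maximum of {max_level} "
                      f"for input length {data_len}; deeper packet keys will not exist.")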

v0lta commented 2 years ago

#25 fixes this.