Yes, I think this came up in https://github.com/v0lta/PyTorch-Wavelet-Toolbox/issues/40.
You've got it: the 2D transform does not touch the leading channel dimension. Quoting from the old discussion:
```python
import torch, pywt, ptwt
import numpy as np
from scipy.datasets import face

# get a colour image and move the channels into the batch dimension.
face = face()[256:512, 256:512]
face_tp = np.transpose(face, [2, 0, 1]).astype(np.float64)

# compute the transform
coeff2d = ptwt.wavedec2(torch.from_numpy(face_tp), pywt.Wavelet("Haar"), level=1, mode='zero')

# invert the transform
rec = ptwt.waverec2(coeff2d, pywt.Wavelet("Haar"))

# move the color channels back
face_rec = rec.squeeze().permute(1, 2, 0)
np.allclose(face, face_rec)
```
Great work! Two questions about input dimensions:
Q1: To apply the 2D DWT to an image batch of shape [B, C, H, W], do I have to fold B and C together (e.g. [B*C, H, W]) before calling ptwt.wavedec2?
Q2: To apply the 3D DWT to a video sequence of shape [B, T, C, H, W], do I have to fold B and C together (e.g. [B*C, T, H, W]) before calling ptwt.wavedec3?
In ptwt:

```
wavedec2:
    data (torch.Tensor): The input data tensor with up to three dimensions.
        2d inputs are interpreted as [height, width],
        3d inputs are interpreted as [batch_size, height, width].

wavedec3:
    data (torch.Tensor): The input data of shape
        [batch_size, length, height, width].
```
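So, per those docstrings, yes to both: fold B and C together before the call and unfold afterwards. Here is a minimal sketch of that pattern; the shapes, wavelet, levels, and modes below are illustrative assumptions, and the reshape after reconstruction absorbs the singleton dimension that some ptwt versions keep on the waverec2 output:

```python
import torch
import pywt
import ptwt

wavelet = pywt.Wavelet("haar")

# Q1: fold [B, C, H, W] into [B*C, H, W] for wavedec2, unfold afterwards.
x = torch.randn(8, 3, 64, 64)  # assumed image batch
b, c, h, w = x.shape
coeffs2d = ptwt.wavedec2(x.reshape(b * c, h, w), wavelet, level=2, mode="zero")
rec2d = ptwt.waverec2(coeffs2d, wavelet)
rec2d = rec2d.reshape(b, c, h, w)  # also absorbs a possible singleton dim
print(torch.allclose(x, rec2d, atol=1e-5))  # True

# Q2: fold [B, T, C, H, W] into [B*C, T, H, W] for wavedec3.
# wavedec3 expects [batch_size, length, height, width], so move C next to B first.
v = torch.randn(2, 16, 3, 32, 32)  # assumed video batch
b, t, c, h, w = v.shape
v_folded = v.permute(0, 2, 1, 3, 4).reshape(b * c, t, h, w)
coeffs3d = ptwt.wavedec3(v_folded, wavelet, level=1, mode="zero")
rec3d = ptwt.waverec3(coeffs3d, wavelet)
rec3d = rec3d.reshape(b, c, t, h, w).permute(0, 2, 1, 3, 4)  # back to [B, T, C, H, W]
print(torch.allclose(v, rec3d, atol=1e-5))  # True
```

Since the separable transforms only touch the trailing spatial (or time-and-space) axes, any leading axis acts as a batch axis, which is why flattening B and C together and splitting them apart afterwards is lossless.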