v0lta / PyTorch-Wavelet-Toolbox

Differentiable fast wavelet transforms in PyTorch with GPU support.
https://pytorch-wavelet-toolbox.readthedocs.io
European Union Public License 1.2
279 stars, 36 forks

Questions about input dimensions #57

Closed: xuesongnie closed this issue 1 year ago

xuesongnie commented 1 year ago

Great work! I have two questions about input dimensions:
Q1: To apply a 2D DWT to images of shape [B, C, H, W], do I have to merge B and C (e.g. into [B*C, H, W]) to use ptwt.wavedec2?
Q2: To apply a 3D DWT to a video sequence of shape [B, T, C, H, W], do I have to merge B and C (e.g. into [B*C, T, H, W]) to use ptwt.wavedec3?

From the ptwt docstrings:
wavedec2: data (torch.Tensor): The input data tensor with up to three dimensions.
    2d inputs are interpreted as [height, width],
    3d inputs are interpreted as [batch_size, height, width].
wavedec3: data (torch.Tensor): The input data of shape
    [batch_size, length, height, width].

v0lta commented 1 year ago

Yes, I think this came up before in https://github.com/v0lta/PyTorch-Wavelet-Toolbox/issues/40.

v0lta commented 1 year ago

You've got it. The 2D transform does not touch the leading dimension, so the colour channels can simply be moved into the batch dimension. Quoting from the old discussion:

import numpy as np
import ptwt
import pywt
import torch
from scipy.datasets import face

# Get a 256x256 crop of a colour image, shape [H, W, C] = [256, 256, 3].
img = face()[256:512, 256:512]
# Move the colour channels to the front so they act as the batch dimension: [C, H, W].
img_chw = np.transpose(img, [2, 0, 1]).astype(np.float64)
# Compute a one-level Haar transform; only the last two axes are transformed.
coeff2d = ptwt.wavedec2(torch.from_numpy(img_chw), pywt.Wavelet("haar"),
                        level=1, mode="zero")
# Invert the transform.
rec = ptwt.waverec2(coeff2d, pywt.Wavelet("haar"))
# Move the colour channels back: [H, W, C].
img_rec = rec.squeeze().permute(1, 2, 0).numpy()

print(np.allclose(img, img_rec))  # True