fbcotter / pytorch_wavelets

Pytorch implementation of 2D Discrete Wavelet (DWT) and Dual Tree Complex Wavelet Transforms (DTCWT) and a DTCWT based ScatterNet

How to make the input and output consistent? #22

Closed: ma3252788 closed this issue 3 years ago

ma3252788 commented 3 years ago

I found that when the input's spatial dimensions are even, the output has the same shape as the input, but when they are odd, the output shape differs.

import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse
xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b').cuda()
X = torch.randn(2, 2048, 26, 32)
Yl, Yh = xfm(X.cuda())
ifm = DTCWTInverse().cuda()
Y = ifm((Yl, Yh))
print(Y.shape)

I got: torch.Size([2, 2048, 26, 32])

import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse
xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b').cuda()
X = torch.randn(2, 2048, 25, 31)
Yl, Yh = xfm(X.cuda())
ifm = DTCWTInverse().cuda()
Y = ifm((Yl, Yh))
print(Y.shape)

I got: torch.Size([2, 2048, 26, 32])

fbcotter commented 3 years ago

Yes, this is expected because of the subsampling in decimated wavelet transforms; it works best with power-of-2 inputs. You can discard the last row and column of the output to recover your original input.
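A minimal sketch of the cropping step described above. The helper name `crop_to_input` is hypothetical (it is not part of pytorch_wavelets), and the random tensors only stand in for the shapes reported in this thread so the snippet runs without a GPU or the library installed.

```python
import torch


def crop_to_input(y, x):
    """Trim trailing rows/columns of y so its last two dims match those of x.

    Hypothetical helper for illustration; not part of pytorch_wavelets.
    """
    h, w = x.shape[-2:]
    return y[..., :h, :w]


# Stand-in tensors using the spatial sizes from the report above
# (channel count reduced to keep the example light).
x = torch.randn(2, 4, 25, 31)   # odd-sized input
y = torch.randn(2, 4, 26, 32)   # shape returned by the inverse transform
y_cropped = crop_to_input(y, x)
print(y_cropped.shape)
```

With the code from the original post, this would presumably be applied as `Y = crop_to_input(ifm((Yl, Yh)), X)`, which drops the extra last row and column when the input height or width is odd and leaves even-sized outputs unchanged.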

ma3252788 commented 3 years ago

OK, I see. I'll try. Thank you for your reply~