thechidiya closed this issue 4 years ago.
It's a little tricky to change the code to work with channels-last input, and I'm not even sure PyTorch supports it. It's probably simplest to wrap the transform so that it permutes the input to NCHW and the outputs back:
```python
import torch
from pytorch_wavelets import DWT


class DWT_NHWC(DWT):
    def forward(self, x):
        # Make the input NCHW for pytorch_wavelets
        x = x.permute(0, 3, 1, 2)
        yl, yh = super().forward(x)
        # Permute the lowpass back to NHWC
        yl = yl.permute(0, 2, 3, 1)
        # Permute each bandpass from (N, C, 3, H, W) to (N, 3, H, W, C)
        yh = [subband.permute(0, 2, 3, 4, 1) for subband in yh]
        return yl, yh


# Input in NHWC format
x = torch.randn(1, 16, 16, 3)
dwt = DWT_NHWC(J=2)
yl, yh = dwt(x)
```
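The permutation indices in the wrapper are easy to get wrong, so here is a small sketch that checks them without needing pytorch_wavelets. It uses NumPy's `np.transpose`, which takes the same axis-order argument as `torch.Tensor.permute`. The shapes (a 16x16 RGB image and an 8x8 bandpass with layout (N, C, 3, H', W')) are illustrative assumptions, not values read from the library.

```python
import numpy as np

# Illustrative shapes: batch 1, 16x16 image, 3 channels (as in the example above).
N, H, W, C = 1, 16, 16, 3
x_nhwc = np.random.randn(N, H, W, C)

# Same axis order as x.permute(0, 3, 1, 2): NHWC -> NCHW
x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))

# Same axis order as yl.permute(0, 2, 3, 1): NCHW -> NHWC
x_back = np.transpose(x_nchw, (0, 2, 3, 1))

# A hypothetical bandpass tensor laid out as (N, C, 3, H', W')
yh_nc3hw = np.random.randn(N, C, 3, H // 2, W // 2)
# Same axis order as subband.permute(0, 2, 3, 4, 1): -> (N, 3, H', W', C)
yh_n3hwc = np.transpose(yh_nc3hw, (0, 2, 3, 4, 1))

print(x_nchw.shape)    # (1, 3, 16, 16)
print(x_back.shape)    # (1, 16, 16, 3)
print(yh_n3hwc.shape)  # (1, 3, 8, 8, 3)
print(np.array_equal(x_back, x_nhwc))  # True: the two permutes undo each other
```

Because the two permutes are inverses, the round trip leaves the data unchanged; only the memory layout the library sees is different.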
Great! Thank you!! I'll check it!
Hi,
I'm trying to use wavelets in my CNN. I'm replacing max-pooling with a wavelet transform for downsampling, and I stumbled upon your code. My image tensors are in (H, W, channels) layout, but your code expects channels first. Is there a way to use your code with channels-last input?
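To make the downsampling idea in the question concrete, here is a minimal NumPy sketch of one level of a 2-D Haar DWT that works directly on channels-last (N, H, W, C) arrays, next to a 2x2 max-pool for comparison. It assumes the Haar wavelet (comparable to `wave='db1'`), even H and W, and no padding. It is not pytorch_wavelets' implementation; the function names `haar_dwt_nhwc` and `max_pool_2x2_nhwc`, the subband order, and the sign conventions are hypothetical choices for illustration.

```python
import numpy as np


def haar_dwt_nhwc(x):
    """One level of an orthonormal 2-D Haar DWT on an (N, H, W, C) array.

    Assumes even H and W. Returns the lowpass band (N, H/2, W/2, C) and the
    three detail bands stacked as (N, 3, H/2, W/2, C). The detail order and
    signs here are assumptions and may not match pytorch_wavelets.
    """
    a = x[:, 0::2, 0::2, :]  # top-left of each 2x2 block
    b = x[:, 0::2, 1::2, :]  # top-right
    c = x[:, 1::2, 0::2, :]  # bottom-left
    d = x[:, 1::2, 1::2, :]  # bottom-right
    ll = (a + b + c + d) / 2  # local average (scaled), the downsampled image
    d_rows = (a + b - c - d) / 2  # top row minus bottom row
    d_cols = (a - b + c - d) / 2  # left column minus right column
    d_diag = (a - b - c + d) / 2  # diagonal detail
    return ll, np.stack([d_rows, d_cols, d_diag], axis=1)


def max_pool_2x2_nhwc(x):
    """2x2 max-pooling with stride 2 on an (N, H, W, C) array (even H, W)."""
    n, h, w, ch = x.shape
    return x.reshape(n, h // 2, 2, w // 2, 2, ch).max(axis=(2, 4))


x = np.random.randn(2, 16, 16, 3)
ll, yh = haar_dwt_nhwc(x)
pooled = max_pool_2x2_nhwc(x)
print(ll.shape, yh.shape, pooled.shape)  # (2, 8, 8, 3) (2, 3, 8, 8, 3) (2, 8, 8, 3)
```

Slicing the spatial axes directly means no permutes are needed for channels-last data. The lowpass band has the same shape as the max-pooled output, so it can stand in for the pooling layer; the detail bands can be dropped or concatenated along the channel axis if the network should keep that information.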