fbcotter / pytorch_wavelets

PyTorch implementation of the 2D Discrete Wavelet Transform (DWT), the Dual-Tree Complex Wavelet Transform (DTCWT), and a DTCWT-based ScatterNet

How to use wavelets for channel last #14

Closed: thechidiya closed this issue 4 years ago

thechidiya commented 4 years ago

Hi,

I'm trying to use wavelets in my CNN, replacing max pooling with a wavelet transform for downsampling, and I stumbled upon your code. My image tensors are in (H, W, channel) format, but your code expects channels first. Can I still use your code for this?
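A wavelet downsampling step can, in principle, operate directly on channels-last data. The sketch below is only an illustration, not pytorch_wavelets' implementation: it assumes NumPy arrays instead of torch tensors, a single level of the orthonormal Haar wavelet, and even H and W, and the helper name `haar_dwt_nhwc` is hypothetical.

```python
import numpy as np


def haar_dwt_nhwc(x):
    """One level of an orthonormal 2D Haar DWT on an (N, H, W, C) array.

    Hypothetical helper for illustration only; it is NOT part of
    pytorch_wavelets and ignores padding modes and other wavelets.
    Returns the lowpass band with shape (N, H/2, W/2, C) and the three
    detail bands stacked as (N, 3, H/2, W/2, C).
    """
    if x.ndim != 4:
        raise ValueError("expected an (N, H, W, C) array")
    if x.shape[1] % 2 or x.shape[2] % 2:
        raise ValueError("H and W must be even in this sketch")
    # The four samples of every non-overlapping 2x2 block (channels stay last)
    a = x[:, 0::2, 0::2, :]  # top-left
    b = x[:, 0::2, 1::2, :]  # top-right
    c = x[:, 1::2, 0::2, :]  # bottom-left
    d = x[:, 1::2, 1::2, :]  # bottom-right
    # Orthonormal Haar analysis: a scaled 4-point Hadamard transform per block
    low = (a + b + c + d) / 2      # equals 2x the 2x2 block average
    diff_w = (a - b + c - d) / 2   # differences along the width axis
    diff_h = (a + b - c - d) / 2   # differences along the height axis
    diff_hw = (a - b - c + d) / 2  # diagonal differences
    high = np.stack([diff_w, diff_h, diff_hw], axis=1)
    return low, high


# Used where a 2x2, stride-2 pool would be: H and W halve, C is unchanged
x = np.random.default_rng(0).standard_normal((1, 16, 16, 3))
low, high = haar_dwt_nhwc(x)
print(low.shape, high.shape)  # (1, 8, 8, 3) (1, 3, 8, 8, 3)
```

Note that the lowpass band on its own behaves like scaled average pooling rather than max pooling. Keeping the detail bands (for example, folding them into extra channels) is one way to retain the information a pooling layer would discard; whether that helps depends on the network.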

fbcotter commented 4 years ago

It's a little tricky to change the code to work with channels-last input, and I'm not even sure PyTorch supports it. It's probably easiest to wrap the module and permute the axes on the way in and out:

import torch
from pytorch_wavelets import DWT


class DWT_NHWC(DWT):
    """DWT that accepts and returns channels-last (NHWC) tensors."""

    def forward(self, x):
        # Permute the NHWC input to NCHW, the layout pytorch_wavelets expects
        x = x.permute(0, 3, 1, 2)
        yl, yh = super().forward(x)
        # Permute the lowpass band back to NHWC
        yl = yl.permute(0, 2, 3, 1)
        # Permute each bandpass level from (N, C, 3, H, W) to (N, 3, H, W, C)
        yh = [subband.permute(0, 2, 3, 4, 1) for subband in yh]
        return yl, yh

# input in NHWC format
x = torch.randn(1, 16, 16, 3)
dwt = DWT_NHWC(J=2)
yl, yh = dwt(x)
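Because the wrapper only reorders axes, its bookkeeping can be checked without the library. This sketch uses NumPy's `ndarray.transpose`, which takes the same axis order as `torch.Tensor.permute`. The placeholder arrays are assumptions: they follow the layout described in the wrapper's comments, (N, C, H', W') for the lowpass and (N, C, 3, H', W') per bandpass level, with H and W assumed to halve at each of the J=2 levels.

```python
import numpy as np

NHWC_TO_NCHW = (0, 3, 1, 2)       # same order as x.permute(0, 3, 1, 2)
NCHW_TO_NHWC = (0, 2, 3, 1)       # same order as yl.permute(0, 2, 3, 1)
NC3HW_TO_N3HWC = (0, 2, 3, 4, 1)  # same order as subband.permute(0, 2, 3, 4, 1)

# Channels-last input, as in the usage example above
x_nhwc = np.arange(1 * 16 * 16 * 3).reshape(1, 16, 16, 3)
x_nchw = x_nhwc.transpose(NHWC_TO_NCHW)
print(x_nchw.shape)  # (1, 3, 16, 16)

# The two 4D permutations are inverses, so the data round-trips unchanged
assert np.array_equal(x_nchw.transpose(NCHW_TO_NHWC), x_nhwc)

# Placeholder NCHW outputs with ASSUMED shapes for J=2 (H and W halve per level)
yl = np.arange(1 * 3 * 4 * 4).reshape(1, 3, 4, 4)
yh = [np.arange(1 * 3 * 3 * 8 * 8).reshape(1, 3, 3, 8, 8),
      np.arange(1 * 3 * 3 * 4 * 4).reshape(1, 3, 3, 4, 4)]

yl_nhwc = yl.transpose(NCHW_TO_NHWC)
yh_nhwc = [band.transpose(NC3HW_TO_N3HWC) for band in yh]
print(yl_nhwc.shape)               # (1, 4, 4, 3)
print([b.shape for b in yh_nhwc])  # [(1, 3, 8, 8, 3), (1, 3, 4, 4, 3)]

# Element [n, c, k, i, j] of a bandpass level lands at [n, k, i, j, c]
assert yh_nhwc[0][0, 2, 5, 7, 1] == yh[0][0, 1, 2, 5, 7]
```

One practical consequence: permuted tensors in PyTorch are generally non-contiguous views, so calling `.contiguous()` on the outputs may be worthwhile if later layers require contiguous memory.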

thechidiya commented 4 years ago

Great, thank you! I'll check it out!