fbcotter / pytorch_wavelets

Pytorch implementation of 2D Discrete Wavelet (DWT) and Dual Tree Complex Wavelet Transforms (DTCWT) and a DTCWT based ScatterNet

Up/Downsampling Questions #24

Closed kierkegaard13 closed 3 years ago

kierkegaard13 commented 3 years ago

Hello,

I don't have much of a signal processing background, so apologies if I'm uninformed here. I'm trying to replicate the recent Adaptive Discriminator Augmentation code from the StyleGAN team in PyTorch, and it seems like they're using the sym6 wavelet to upsample an image before applying a projective matrix and then downsampling with the same wavelet. They use TensorFlow's tf.nn.depthwise_conv2d_backprop_input function to achieve this. I can't find much literature on up- and downsampling via wavelets, but the most I could gather was that the DWT could be used to accomplish this.

I tried using the DWT in this library to upsample by first initializing a matrix of zeros with the upsampled dimensions, filling the upper-left quadrant with the original image, and then applying the DWT (following the instructions in a Microsoft paper I came across), but I'm not sure how to use the output to reconstruct the upsampled image.

The input image dimensions are [1, 3, 128, 128], and I want the upsampled dimensions to be [1, 3, 256, 256]. After applying the DWT with J=3 and wave='sym6', I get yl with dimensions [1, 3, 41, 41], and yh as a list of tensors with dimensions [1, 3, 3, 133, 133], [1, 3, 3, 72, 72], and [1, 3, 3, 41, 41]. Is there a way to assemble that output into an upsampled image, or am I doing something fundamentally wrong here? And how might I downsample the resulting upsampled image after applying a transformation to it? Would I just pass (yl2, yh) to the inverse transform?

Any help here is much appreciated. Thanks.

fbcotter commented 3 years ago

I'm not quite sure what you want to replicate, perhaps you can link me to the tf code that does the upsampling?

Let's start from a different angle. Say you have a target size, [1, 3, 256, 256], and use the DWT to downsample it to get a pyramid of signals:

import torch
from pytorch_wavelets import DWT, IDWT

x = torch.randn(1, 3, 256, 256)
dwt = DWT(J=1, wave='sym6', mode='zero')
yl, yh = dwt(x)
yl.shape
>>> torch.Size([1, 3, 133, 133])
yh[0].shape
>>> torch.Size([1, 3, 3, 133, 133])

What's happened here is that yl is obtained by convolving x with the lowpass analysis filters h0. For a sym6 wavelet, these have length 12 (try printing dwt.h0_col or dwt.h0_row). Convolving a 256-length signal with a 12-length filter gives an output of length 256 + 12 - 1 = 267. This is then downsampled by 2 (rounding down), giving the 133 we see in yl. The same is done for the 3 bandpass outputs, just with combinations of h0 and h1 for rows and columns.
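
For reference, a quick way to check that size arithmetic in code, continuing from the snippet above and using the dwt.h0_col attribute mentioned there:

L = dwt.h0_col.numel()     # sym6 lowpass filter length: 12
print(256 + L - 1)         # length of the full convolution: 267
print((256 + L - 1) // 2)  # after downsampling by 2: 133, matching yl.shape[-1]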

If you want to upsample a signal by 2 but don't have access to the bandpass coefficients, all we are doing is taking this yl, inserting a zero between every pair of samples, and smoothing with a lowpass filter. It's not too different from just using bilinear interpolation (although the lowpass filter will have a different frequency response).
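
If it helps to see that outside the library, here is a rough 1D sketch of the "insert zeros, then smooth" idea. This is not how the library does it internally; the short triangle filter below is just a stand-in for a proper wavelet synthesis lowpass like sym6 (with this particular filter the result is exactly linear interpolation, which is why the whole thing behaves much like bilinear resizing in 2D):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 8)                   # short 1D signal
up = torch.zeros(1, 1, 16)
up[..., ::2] = x                           # insert a zero between every pair of samples
h = torch.tensor([[[0.5, 1.0, 0.5]]])      # crude lowpass, stand-in for a wavelet synthesis filter
y = F.conv1d(up, h, padding=1)             # smooth away the zero-insertion artefacts
print(y.shape)                             # torch.Size([1, 1, 16]) -- 2x upsampled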

To do this here, you would drop your signal into the yl coefficients. Note that after the initial convolution at the higher sample rate, the 256-length signal gained 5 samples before and 6 samples after; this became 2 samples before and 3 samples after once downsampled. Then:

idwt = IDWT(wave='sym6', mode='zero')
yh = [torch.zeros(1, 3, 3, 133, 133)]   # list of highpass coeffs (one scale), all zero
yl = torch.zeros(1, 3, 133, 133)
yl[:, :, 2:-3, 2:-3] = your_signal      # your_signal has shape [1, 3, 128, 128]
upsampled = idwt((yl, yh))              # shape [1, 3, 256, 256]

kierkegaard13 commented 3 years ago

The code I'm specifically trying to replicate is here: https://github.com/NVlabs/stylegan2-ada/blob/main/training/augment.py#L424. Here's the paper if you're interested: https://arxiv.org/abs/2006.06676.

Thanks for that explanation. I think that makes sense for the most part. I'm not sure I get the padding scheme, though, or why the 256-length signal gains pixels after the convolution; I would expect the signal to be padded and then lose 11 pixels after the convolution with the sym6 filter. The upsampling procedure makes complete sense, but if I wanted to downsample to exactly half the pixels, I'm not 100% sure how to do that.

fbcotter commented 3 years ago

Ok, I see what they're doing in their implementation. Re padding: this is a property of convolutions. Check out the gif here, where two square pulses are convolved: https://en.wikipedia.org/wiki/Convolution#/media/File:Convolution_of_box_signal_with_itself2.gif; the output triangle signal has support before and after the input square.

In typical conv networks, we just throw away the extra info before and after. This is usually not a problem, as the filters h are very small (3x3) compared to the signal x, so it's easier to just discard it.

Handling the padding in the DWT setting is a little tricky, but not too complex. What you want to do is:

x # shape [1, 3, 128, 128]
idwt = IDWT(wave='sym6', mode='zero')
dwt = DWT(J=1, wave='sym6', mode='zero')
# pad x to make it the expected 133x133 shape before giving to the IDWT - Use zeros to do this
upsampled = idwt((torch.nn.functional.pad(x, (2, 3, 2, 3)), [None, ]))  # shape [1, 3, 256, 256]
...
# The decimated version will have output 133x133 - discard the extra regions. These will not be zero, a bit like the
# triangle output from the gif.
downsampled = dwt(upsampled)[0][..., 2:-3, 2:-3]

If you then compare downsampled and x, they will be identical except for some errors at the borders, again a consequence of the padding.
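
If you want to check that numerically, something like the following should print True (the 16-pixel margin trimmed here is just a generous guess at how far the border errors from the zero padding reach):

m = 16  # ignore a border strip where the zero padding causes small errors
print(torch.allclose(downsampled[..., m:-m, m:-m], x[..., m:-m, m:-m], atol=1e-5))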

kierkegaard13 commented 3 years ago

Great, thanks. That makes perfect sense. Out of curiosity, after looking at their implementation, do you think there are any differences between their approach and what you've explained here? The thing I was most confused by was the function tf.nn.depthwise_conv2d_backprop_input, because I couldn't find any example usage, and from the description in the docs it seemed like it should return gradient information.

kierkegaard13 commented 3 years ago

I tested out the up/downsampling method you mentioned and it works perfectly. I'm using pytorch-geometry to apply the projective transformation, and as far as I can tell it replicates everything in the paper. Thanks again for the help.