fbcotter / pytorch_wavelets

Pytorch implementation of 2D Discrete Wavelet (DWT) and Dual Tree Complex Wavelet Transforms (DTCWT) and a DTCWT based ScatterNet

using dtcwt as a high freq feature extractor #41

Closed: jamesrobertwilliams closed this issue 2 years ago

jamesrobertwilliams commented 2 years ago

I would like to extract the high-frequency components of an image, and I am wondering how I can do this using the DTCWT. I looked at the examples, and I presume I can do something like this:

```python
import torch
from pytorch_wavelets import DTCWTForward

xfm1 = DTCWTForward(J=3)
x = torch.randn(1, 3, 64, 64)  # 3 channel 64x64 img example
yl, yh = xfm1(x)
print(yl.shape)
print(len(yh), yh[0].shape)
```

I presume here yl and yh are low and high frequency components respectively. However, I see that the HF components are complex :(

I guess my question is, how do I (in a sensible way) get a 1D flattened real vector (amplitude?) from the yh components.

Apologies if this is a daft question, and thank you ever so much for this awesome package.

fbcotter commented 2 years ago

Hi @jamesrobertwilliams, sorry for the slow reply! I don't often get back to this package, as I have been looking at other things lately.

The best way to get real-valued high-frequency components is to take the absolute value (magnitude) of the complex coefficients.

The real and imaginary parts are both DWTs, so taking either one on its own would also work. However, near edges the real and imaginary parts each oscillate, so either part alone is much less stable than the magnitude of the complex coefficient. See problem 2 on page 2 of http://sigproc.eng.cam.ac.uk/foswiki/pub/Main/NGK/sp_mag_finalsub.pdf.

For the first (finest) level, this gives an output of shape [1, 3, 6, 32, 32]: a highpass image for each channel of your input and for each of the 6 directional subbands of the DTCWT (see figure 15 of the above pdf). The coarser levels have the same layout at 16x16 and 8x8. Flattening it can be done in many ways, and the right choice depends on your application.

A common approach is to compute the energy in each subband, i.e. the sum of the squared absolute values of its coefficients, optionally normalized by the total energy. For example:

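```python
import torch
from pytorch_wavelets import DTCWTForward

xfm1 = DTCWTForward(J=3)
x = torch.randn(1, 1, 64, 64)  # single-channel 64x64 example
yl, yh = xfm1(x)

lowpass_energy = (yl**2).sum()
print(f"Lowpass energy is: {lowpass_energy}")

# One row per level j, one column per oriented subband
highpass_energies = torch.zeros(3, 6)
for j in range(3):
    for band in range(6):
        # Last dim of yh[j] holds (real, imag), so |z|^2 = real^2 + imag^2
        real = yh[j][:, :, band, :, :, 0]
        imag = yh[j][:, :, band, :, :, 1]
        subband_energy = real**2 + imag**2
        highpass_energies[j, band] = subband_energy.sum()
print(f"Highpass energies are:\n{highpass_energies}")
print(f"Input energy: {(x**2).sum()}")
print(f"Sum of total band energy: {lowpass_energy + highpass_energies.sum()}")
```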

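You'll notice that the energy of the input signal is very close to the sum of the energies of the subbands. This is because the DTCWT is approximately a tight frame: it is 4:1 redundant in 2D, so it cannot be strictly orthonormal, but it is built from (near-)orthogonal filter banks, so energy is very nearly preserved (see the bottom-right equation on page 11 of that pdf).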