v0lta / PyTorch-Wavelet-Toolbox

Differentiable fast wavelet transforms in PyTorch with GPU support.
https://pytorch-wavelet-toolbox.readthedocs.io
European Union Public License 1.2
289 stars, 36 forks

A question about WaveletPacket, WaveletPacket2D? #69

Closed: Wongboo closed this issue 1 year ago

Wongboo commented 1 year ago

I still don't understand how to customize the decomposition levels so that they adapt to the signal, like in the lower-right part of the image. Can the adaptation differ between the signals in a batch?

v0lta commented 1 year ago

Dear @Wongboo, the code computes a single transformation per tensor. To apply different transforms to different signals of a batched tensor, you can split the tensor along the batch dimension. For example:

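```python
import torch
import ptwt

# A batch of ten 20x20 inputs.
trnd = torch.rand((10, 20, 20))
# Split along the batch dimension into two chunks of five samples each.
t1, t2 = torch.split(trnd, 5)
# Standard 2D wavelet decomposition for the first chunk...
result_wavedec2 = ptwt.wavedec2(t1, "haar")
# ...and a fully separable 2D decomposition for the second chunk.
result_fswavedec2 = ptwt.fswavedec2(t2, "haar")
```
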
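Taken to the extreme, the same splitting idea gives every signal in a batch its own decomposition depth. The sketch below uses only NumPy and a hypothetical `haar_wavedec` helper (not part of ptwt) that implements an orthonormal Haar DWT without padding, so it assumes each signal length is divisible by `2**level`. With ptwt itself, you would presumably call `ptwt.wavedec` on each slice with a different `level` argument instead.

```python
import numpy as np


def haar_wavedec(x: np.ndarray, level: int) -> list:
    """Hypothetical helper: multi-level orthonormal 1D Haar DWT.

    Returns [cA_n, cD_n, ..., cD_1], i.e. the coarsest approximation first,
    followed by the details from coarse to fine. Assumes len(x) is divisible
    by 2**level, so no boundary padding is needed.
    """
    coeffs = []
    approx = np.asarray(x, dtype=np.float64)
    for _ in range(level):
        even, odd = approx[0::2], approx[1::2]
        coeffs.append((even - odd) / np.sqrt(2.0))  # detail coefficients
        approx = (even + odd) / np.sqrt(2.0)        # approximation coefficients
    return [approx] + coeffs[::-1]


# A batch of four signals of length 16; each one gets its own depth.
batch = np.random.default_rng(0).standard_normal((4, 16))
levels = [1, 2, 3, 4]
results = [haar_wavedec(signal, lvl) for signal, lvl in zip(batch, levels)]

for lvl, coeffs in zip(levels, results):
    print(lvl, [c.shape[0] for c in coeffs])
```

Because the per-sample results have different numbers of coefficient arrays, they are kept in a Python list rather than stacked into one tensor; that is the price of a per-signal depth.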
v0lta commented 1 year ago

Your question could also refer to accessing elements in a wavelet-packet tree. To access the nodes of, for example, the 1D tree, index it with path strings made of 'a' (approximation) and 'd' (detail) keys:

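```python
import matplotlib.pyplot as plt
import numpy as np
import pywt
import scipy.signal
import torch

import ptwt

# Linear chirp: the frequency rises from f0=1 at t=0 to f1=50 at t=10.
t = np.linspace(0, 10, 1500)
w = scipy.signal.chirp(t, f0=1, f1=50, t1=10, method="linear")
wp = ptwt.WaveletPacket(data=torch.from_numpy(w.astype(np.float32)),
                        wavelet=pywt.Wavelet("db3"), mode="reflect")

# Level 1: the approximation ('a') and detail ('d') nodes.
for key in ["a", "d"]:
    plt.plot(wp[key][0])
plt.show()
# Level 2: each path lists the branches taken from the root,
# e.g. 'da' is the approximation part of the level-1 detail node.
for key in ["aa", "da", "ad", "dd"]:
    plt.plot(wp[key][0])
plt.show()
```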
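The same key-based indexing is one way to build a mixed-level decomposition: instead of reading all nodes at one level, pick a set of keys at different depths. The NumPy-only sketch below illustrates this with hypothetical helpers (`haar_step`, `haar_packet_tree`, `is_complete_basis`) that are not part of ptwt; it assumes an orthonormal Haar filter, no padding, and the same path naming as the keys above. The leaf selection here is picked by hand; with ptwt you would presumably read the chosen nodes as `wp[key]` instead.

```python
import numpy as np


def haar_step(x):
    """One orthonormal Haar analysis step: returns (approximation, detail)."""
    even, odd = x[0::2], x[1::2]
    return (even + odd) / np.sqrt(2.0), (even - odd) / np.sqrt(2.0)


def haar_packet_tree(x, maxlevel):
    """Hypothetical helper: full Haar wavelet-packet tree stored in a dict.

    Keys are paths read from the root ('a', 'd', 'aa', 'ad', ...), so 'ad'
    is the detail part of the level-1 approximation node. The root has the
    key ''. Assumes len(x) is divisible by 2**maxlevel.
    """
    tree = {"": np.asarray(x, dtype=np.float64)}
    frontier = [""]
    for _ in range(maxlevel):
        next_frontier = []
        for path in frontier:
            a, d = haar_step(tree[path])
            tree[path + "a"], tree[path + "d"] = a, d
            next_frontier += [path + "a", path + "d"]
        frontier = next_frontier
    return tree


def is_complete_basis(keys):
    """True if the keys tile the tree: no key is a prefix of another and
    the dyadic shares 2**-len(key) add up to one."""
    keys = list(keys)
    no_prefix = not any(k != j and j.startswith(k) for k in keys for j in keys)
    return bool(no_prefix and np.isclose(sum(2.0 ** -len(k) for k in keys), 1.0))


t = np.linspace(0, 1, 64)
signal = np.sin(2 * np.pi * 3 * t) + 0.1 * np.sin(2 * np.pi * 25 * t)
tree = haar_packet_tree(signal, maxlevel=3)

# Mixed-level selection: refine the 'a' branch further, keep 'd' at level 1.
leaves = ["aaa", "aad", "ad", "d"]
print(is_complete_basis(leaves))
# For an orthonormal filter, the energy of a complete selection is preserved.
energy = sum(np.sum(tree[k] ** 2) for k in leaves)
print(np.isclose(energy, np.sum(signal ** 2)))
```

A selection that fails `is_complete_basis` either covers part of the signal twice (one key is a prefix of another) or leaves a gap, so its coefficients would not represent the input exactly.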