v0lta / PyTorch-Wavelet-Toolbox

Differentiable fast wavelet transforms in PyTorch with GPU support.
https://pytorch-wavelet-toolbox.readthedocs.io
European Union Public License 1.2

A question about WaveletPacket, WaveletPacket2D? #69

Closed: Wongboo closed this issue 11 months ago.

Wongboo commented 11 months ago

I still don't understand how to customize the decomposition level so that it adapts to the signal, like in the lower right part of the image. Can the adaptation differ between the signals in a batch?

v0lta commented 11 months ago

Dear @Wongboo, the code computes a single transformation per tensor. However, to apply different transforms to different signals of a batched tensor, you can split the tensor along the batch dimension. For example:

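import torch
import ptwt

# A batch of ten 20x20 images.
trnd = torch.rand((10, 20, 20))
# Split along the batch dimension (dim 0) into two chunks of five.
t1, t2 = torch.split(trnd, 5)
# Standard 2D wavelet decomposition for the first chunk ...
result_wavedec2 = ptwt.wavedec2(t1, "haar")
# ... and a fully separable 2D decomposition for the second chunk.
result_fswavedec2 = ptwt.fswavedec2(t2, "haar")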
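Taking the split idea down to single signals, each entry of a batch can also get its own decomposition depth. The sketch below is a plain-Python stand-in, not ptwt code: `haar_dwt` and `haar_wavedec` are hypothetical helpers implementing an orthonormal Haar decomposition on lists, and the per-signal `levels` list is an assumed input. With ptwt itself, the analogous loop would run over `torch.unbind(batch)` and pass each signal its own `level` to `ptwt.wavedec`.

```python
import math
from typing import List, Tuple

SQRT2 = math.sqrt(2.0)


def haar_dwt(x: List[float]) -> Tuple[List[float], List[float]]:
    """One level of the orthonormal Haar transform (hypothetical helper)."""
    if len(x) % 2:
        raise ValueError("this sketch expects an even-length signal")
    pairs = range(len(x) // 2)
    approx = [(x[2 * k] + x[2 * k + 1]) / SQRT2 for k in pairs]
    detail = [(x[2 * k] - x[2 * k + 1]) / SQRT2 for k in pairs]
    return approx, detail


def haar_wavedec(x: List[float], level: int) -> List[List[float]]:
    """Multilevel Haar decomposition (hypothetical helper).

    Returns [cA_level, cD_level, ..., cD_1], the same ordering pywt.wavedec uses.
    """
    coeffs: List[List[float]] = []
    approx = list(x)
    for _ in range(level):
        approx, detail = haar_dwt(approx)
        coeffs.insert(0, detail)  # coarser details move to the front
    coeffs.insert(0, approx)
    return coeffs


# A "batch" of three length-8 signals, each decomposed to its own depth.
batch = [
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
    [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0],
    [4.0, 4.0, 4.0, 4.0, 2.0, 2.0, 2.0, 2.0],
]
levels = [1, 2, 3]  # assumed per-signal choice of depth

results = [haar_wavedec(sig, lvl) for sig, lvl in zip(batch, levels)]
for lvl, coeffs in zip(levels, results):
    print(lvl, [len(c) for c in coeffs])
# prints: 1 [4, 4] / 2 [2, 2, 4] / 3 [1, 1, 2, 4]
```

Because the loop handles one signal at a time, nothing forces the signals to share a depth; the cost is that the batch is no longer processed as one tensor operation.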
v0lta commented 11 months ago

Your question could also refer to accessing the nodes of a wavelet-packet tree. To access different nodes of, for example, the 1D tree, use keys made of the characters 'a' (approximation) and 'd' (detail):

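import matplotlib.pyplot as plt
import numpy as np
import ptwt
import pywt
import scipy.signal
import torch

# A linear chirp whose frequency rises from 1 to 50 over t in [0, 10].
t = np.linspace(0, 10, 1500)
w = scipy.signal.chirp(t, f0=1, f1=50, t1=10, method="linear")
wp = ptwt.WaveletPacket(data=torch.from_numpy(w.astype(np.float32)),
                        wavelet=pywt.Wavelet("db3"), mode="reflect")

# Level 1 nodes; [0] picks the first entry along the leading (batch) dimension.
for key in ["a", "d"]:
    plt.plot(wp[key][0], label=key)
plt.legend()
plt.show()

# Level 2 nodes: each key extends its parent's key by one 'a' or 'd'.
for key in ["aa", "da", "ad", "dd"]:
    plt.plot(wp[key][0], label=key)
plt.legend()
plt.show()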
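To see how the string keys used with `wp[...]` in the example map onto the tree, here is a minimal plain-Python sketch, assuming orthonormal Haar filters and an even-length input at every level. It is not how ptwt computes packets (ptwt works on tensors with convolutions); `haar_split` and `haar_packet_tree` are hypothetical helpers. A child's key is its parent's key plus 'a' or 'd', so 'ad' is the detail child of node 'a'.

```python
import math
from typing import Dict, List, Tuple

SQRT2 = math.sqrt(2.0)


def haar_split(x: List[float]) -> Tuple[List[float], List[float]]:
    """Split a node into its Haar lowpass ('a') and highpass ('d') children."""
    if len(x) % 2:
        raise ValueError("this sketch expects even-length nodes")
    pairs = range(len(x) // 2)
    a = [(x[2 * k] + x[2 * k + 1]) / SQRT2 for k in pairs]
    d = [(x[2 * k] - x[2 * k + 1]) / SQRT2 for k in pairs]
    return a, d


def haar_packet_tree(x: List[float], maxlevel: int) -> Dict[str, List[float]]:
    """Full Haar packet tree keyed by 'a'/'d' paths (hypothetical helper)."""
    tree = {"": list(x)}  # in this sketch the root key "" holds the input
    frontier = [""]
    for _ in range(maxlevel):
        next_frontier = []
        for path in frontier:
            a, d = haar_split(tree[path])
            tree[path + "a"] = a  # lowpass child
            tree[path + "d"] = d  # highpass child
            next_frontier += [path + "a", path + "d"]
        frontier = next_frontier
    return tree


signal = [1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0]
tree = haar_packet_tree(signal, maxlevel=2)
level2 = sorted(k for k in tree if len(k) == 2)
print(level2)  # ['aa', 'ad', 'da', 'dd']
print(tree["aa"], tree["ad"])  # approximately [5, 13] and [-1, -1]
```

Each level halves the node length, so all four level-2 nodes here hold len(signal) / 4 = 2 coefficients.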