sh-lee-prml / PeriodWave

The official implementation of PeriodWave and PeriodWave-Turbo

How did you deal with DWT shape differences? #3

zaptrem opened this issue 3 weeks ago

zaptrem commented 3 weeks ago


Each band that comes out of the DWT has a different length. How did you fit them all into the model, where all inputs have to be the same length? Can you explain the actual shape changes a waveform goes through on its way to and from your model, e.g., Wave [B, T] -> ??? -> Wave [B, T]? The diagram in the paper is unclear, since the DWT spits out several different sequence lengths.

sh-lee-prml commented 3 weeks ago

Thanks for your interest.

[B,1,T] --> DWT --> [B,4,T//4]

We use the second dimension to index each target DWT component:

[B,DWT_i:DWT_i+1,T//4] is x1 for the i-th model.

For progressive generation, which generates the lower bands first, we can condition on the DWT components of all previous bands together, [B,:DWT_i,T//4].

So the network input during training is

Concat(x1 of [B,:DWT_i,T//4], xt of [B,DWT_i:DWT_i+1,T//4], dim=1)

The output of the network is the vector field of xt for [B,DWT_i:DWT_i+1,T//4].

At inference, the ground-truth x1 of [B,:DWT_i,T//4] is replaced with the previously generated DWT components.
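
To make the shapes concrete, here is a minimal sketch of the bookkeeping above. This is not the actual PeriodWave code: all names are hypothetical, the OT-style straight path for xt is only an illustrative assumption, and the sampled band is a placeholder for whatever the ODE solver produces.

import torch

B, T = 8, 1024
x1 = torch.randn(B, 4, T // 4)   # stand-in for the ground-truth DWT components, stacked on dim 1
i = 2                            # index of the band currently being generated

# Training: condition on the ground-truth lower bands plus the noisy sample of band i
t = torch.rand(B, 1, 1)
x0 = torch.randn(B, 1, T // 4)                 # noise for band i
target = x1[:, i:i + 1]                        # [B,1,T//4], the band to generate
xt = (1 - t) * x0 + t * target                 # a point on the conditional path (assumed OT-CFM style)
net_in = torch.cat([x1[:, :i], xt], dim=1)     # [B,i+1,T//4]
# the network is trained to predict the vector field of xt for band i

# Inference: generated bands progressively replace the ground-truth conditioning
generated = []
for i in range(4):
    xt = torch.randn(B, 1, T // 4)                 # fresh noise for band i
    net_in = torch.cat(generated + [xt], dim=1)    # [B,i+1,T//4]
    band_i = torch.randn(B, 1, T // 4)             # placeholder: band i sampled via the ODE solver
    generated.append(band_i)
x1_hat = torch.cat(generated, dim=1)               # [B,4,T//4], ready for iDWT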

For waveform reconstruction,

[B,4,T//4] --> iDWT --> [B,1,T]

Thanks!

zaptrem commented 3 weeks ago

Thanks for the prompt response! Doesn't the DWT output something like the shapes below (since by definition it trades off time and frequency resolution)? T is different for each level. The example below uses pytorch_wavelets:

import torch
from pytorch_wavelets import DWT1DForward, DWT1DInverse  # or simply DWT1D, IDWT1D

# 3-level db6 transform: each level produces a different output length
dwt = DWT1DForward(wave='db6', J=3)
X = torch.randn(10, 5, 100)
yl, yh = dwt(X)  # yl: lowpass coefficients, yh: list of highpass coefficients per level
print(yl.shape)
>>> torch.Size([10, 5, 22])
print(yh[0].shape)
>>> torch.Size([10, 5, 55])
print(yh[1].shape)
>>> torch.Size([10, 5, 33])
print(yh[2].shape)
>>> torch.Size([10, 5, 22])
idwt = DWT1DInverse(wave='db6')
x = idwt((yl, yh))

Which DWT library are you using?

sh-lee-prml commented 3 weeks ago
import torch
from pytorch_wavelets import DWT1DForward, DWT1DInverse  # or simply DWT1D, IDWT1D

y = torch.randn(10, 1, 128)

# Single-level transform with the default Haar wavelet ('db1'),
# which splits the time axis exactly in half with no length mismatch
dwt = DWT1DForward()
idwt = DWT1DInverse()

# First level: returns (lowpass, list of highpass bands), each [10, 1, 64]
x_dwt1, x_dwt2 = dwt(y)
# Second level: transform each half again -> four subbands of [10, 1, 32]
x_dwt1_a, x_dwt1_b = dwt(x_dwt1)
x_dwt2_a, x_dwt2_b = dwt(x_dwt2[0])

print(x_dwt1_a.shape)     # torch.Size([10, 1, 32])
print(x_dwt1_b[0].shape)  # torch.Size([10, 1, 32])
print(x_dwt2_a.shape)     # torch.Size([10, 1, 32])
print(x_dwt2_b[0].shape)  # torch.Size([10, 1, 32])

# Stack the four subbands on the channel dim: [B, 4, T//4]
x1 = torch.concat([x_dwt1_a, x_dwt1_b[0], x_dwt2_a, x_dwt2_b[0]], dim=1)

print(x1.shape)           # torch.Size([10, 4, 32])

# Invert level by level to recover the waveform
x_low = idwt([x_dwt1_a, [x_dwt1_b[0]]])
x_high = idwt([x_dwt2_a, [x_dwt2_b[0]]])
x = idwt([x_low, [x_high]])

print(x.shape)            # torch.Size([10, 1, 128])

Try this. Note that T should be divisible by 4 (T % 4 == 0); if it is not, you can use zero-padding.
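
For example, one minimal way to zero-pad T up to a multiple of 4 (a hypothetical snippet, using torch.nn.functional.pad, not code from this repo):

import torch
import torch.nn.functional as F

y = torch.randn(10, 1, 130)   # T = 130 is not a multiple of 4
pad = (-y.shape[-1]) % 4      # samples needed to reach the next multiple of 4
y = F.pad(y, (0, pad))        # zero-pad the time axis on the right: T -> 132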

zaptrem commented 3 weeks ago

Ah, that makes sense. Thanks!