BGU-CS-VIL / WTConv

Wavelet Convolutions for Large Receptive Fields. ECCV 2024.
MIT License

A better padding method? #21

Open caixd-220529 opened 3 days ago

caixd-220529 commented 3 days ago

Thanks for the wonderful work. However, I have a question about the padding code. In the model, padding is applied in two places: https://github.com/BGU-CS-VIL/WTConv/blob/26af4e2a72a0b6a0275a0f845428997a231204b5/wtconv/wtconv2d.py#L55 and https://github.com/BGU-CS-VIL/WTConv/blob/26af4e2a72a0b6a0275a0f845428997a231204b5/wtconv/util/wavelet.py#L31. I think this padding strategy only works when the wavelet is set to 'db1'. In my experiment, the padding used in the code leads to an incorrect reconstruction when the wavelet is set to something else, such as 'db2'.

import torch
import torch.nn.functional as F
import pywt


def wtconv_inverse_wtconv_exploration_v1019():
    x = torch.arange(4)[1:].float()
    wavelet_name = 'db2'
    filters_lo = torch.nn.Parameter(torch.Tensor([pywt.Wavelet(wavelet_name).dec_lo[::-1]]), requires_grad=False)
    filters_hi = torch.nn.Parameter(torch.Tensor([pywt.Wavelet(wavelet_name).dec_hi[::-1]]), requires_grad=False)
    # a, b, c, d = pywt.Wavelet(wavelet_name).dec_lo[::-1]
    # filters_lo = torch.nn.Parameter(torch.Tensor([[a, b, c, d]]), requires_grad=False)
    # filters_hi = torch.nn.Parameter(torch.Tensor([[d, -c, b, -a]]), requires_grad=False)
    print('filters_lo:', filters_lo.data)
    print('filters_hi:', filters_hi.data)
    # padding as in the repo: symmetric pad of (filter_len // 2 - 1),
    # plus one extra tail sample for odd-length inputs (below)
    head_pad = tail_pad = len(pywt.Wavelet(wavelet_name).dec_lo) // 2 - 1
    if len(x) % 2 == 1:
        tail_pad += 1
    x_pad = F.pad(x, (head_pad, tail_pad))
    approximation = F.conv1d(x_pad.view(1, 1, -1), filters_lo.view(1, 1, -1), stride=2)
    detail = F.conv1d(x_pad.view(1, 1, -1), filters_hi.view(1, 1, -1), stride=2)
    print('approximation:', approximation)
    print('detail:', detail)
    reconstruct = (F.conv_transpose1d(approximation, filters_lo.view(1, 1, -1), stride=2, padding=[head_pad]) +
                   F.conv_transpose1d(detail, filters_hi.view(1, 1, -1), stride=2, padding=[head_pad]))
    print('reconstruct:', reconstruct)

Here, the reconstruction will be

reconstruct: tensor([[[0.7500, 2.0000, 3.0000, 0.0000]]])

where the first element is different from the input.
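
For reference (this is just the standard pywt API, not the repo code), PyWavelets keeps the redundant boundary coefficients instead of forcing the output to be exactly half the input length, and its dwt/idwt pair should then reconstruct the signal exactly. A minimal sanity check:

import numpy as np
import pywt

x = np.array([1.0, 2.0, 3.0])
cA, cD = pywt.dwt(x, 'db2', mode='zero')      # keeps the redundant boundary coefficients
rec = pywt.idwt(cA, cD, 'db2', mode='zero')   # one extra trailing sample for odd-length inputs
print(rec[:len(x)])                           # matches x up to float error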

I think a better implementation would be:

def conv_inverse_conv_exploration_v1019():
    x = torch.arange(4)[1:].float()
    wavelet_name = 'db2'
    filters_lo = torch.nn.Parameter(torch.Tensor([pywt.Wavelet(wavelet_name).dec_lo[::-1]]), requires_grad=False)
    filters_hi = torch.nn.Parameter(torch.Tensor([pywt.Wavelet(wavelet_name).dec_hi[::-1]]), requires_grad=False)
    # a, b, c, d = pywt.Wavelet(wavelet_name).dec_lo[::-1]
    # filters_lo = torch.nn.Parameter(torch.Tensor([[a, b, c, d]]), requires_grad=False)
    # filters_hi = torch.nn.Parameter(torch.Tensor([[d, -c, b, -a]]), requires_grad=False)
    print('filters_lo:', filters_lo.data)
    print('filters_hi:', filters_hi.data)
    # proposed padding: head pad of (filter_len - 2) and tail pad of (filter_len - 1),
    # so every filter window that overlaps the signal is computed
    head_pad = len(pywt.Wavelet(wavelet_name).dec_lo) - 2
    tail_pad = len(pywt.Wavelet(wavelet_name).dec_lo) - 1
    x_pad = F.pad(x, (head_pad, tail_pad))
    approximation = F.conv1d(x_pad.view(1, 1, -1), filters_lo.view(1, 1, -1), stride=2)
    detail = F.conv1d(x_pad.view(1, 1, -1), filters_hi.view(1, 1, -1), stride=2)
    print('approximation:', approximation)
    print('detail:', detail)
    reconstruct = (F.conv_transpose1d(approximation, filters_lo.view(1, 1, -1), stride=2, padding=[head_pad]) +
                   F.conv_transpose1d(detail, filters_hi.view(1, 1, -1), stride=2, padding=[head_pad]))
    print('reconstruct:', reconstruct)

In this way, the length of the detail or approximation coefficients will not be exactly half of the input length, but the reconstruction will be the same as the input. This change may improve the model's performance when the wavelet is set to something other than 'db1'.
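
To make the size difference concrete: with filter length k, the proposed padding (head k-2, tail k-1) gives coefficients of length (n + k - 1) // 2, while the current scheme gives ceil(n / 2). A small sketch of the difference (k = 4 for 'db2'):

k = 4  # filter length of 'db2'
for n in [3, 4, 7, 56]:
    proposed = (n + k - 1) // 2   # head pad k-2, tail pad k-1, as in the code above
    current = (n + 1) // 2        # ceil(n / 2), as in the current padding scheme
    print(f'n={n}: proposed={proposed}, current={current}')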

I do not know whether I have explained it clearly since my English is poor. Thanks for the great work again!

shahaffind commented 2 days ago

Thanks for pointing this out.

There are indeed boundary issues when dealing with odd-length inputs and/or larger WT bases. We chose to keep the output resolution at half the input resolution (up to a size-1 padding to deal with odd lengths), and, as is common practice for that, we lose some accuracy around the boundaries.

Your fix can work; to put it in the WTConv scheme, it is equivalent to using

pad = ((2*filters.shape[2]-3) // 2, (2*filters.shape[3]-3) // 2)

in https://github.com/BGU-CS-VIL/WTConv/blob/26af4e2a72a0b6a0275a0f845428997a231204b5/wtconv/util/wavelet.py#L31 as well as in its inverse (L39)
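
For what it's worth, a quick standalone 1D analogue of that pad (which works out to k - 2 per side for a length-k Daubechies filter, e.g. 2 for 'db2'), written outside the repo code, should reconstruct even-length inputs exactly; odd lengths would still need the extra size-1 padding mentioned above:

import torch
import torch.nn.functional as F
import pywt

x = torch.arange(1, 7).float()                 # even-length input
w = pywt.Wavelet('db2')
k = len(w.dec_lo)
lo = torch.tensor(w.dec_lo[::-1]).view(1, 1, -1)
hi = torch.tensor(w.dec_hi[::-1]).view(1, 1, -1)

pad = (2 * k - 3) // 2                         # == k - 2 for even-length filters
cA = F.conv1d(x.view(1, 1, -1), lo, stride=2, padding=pad)
cD = F.conv1d(x.view(1, 1, -1), hi, stride=2, padding=pad)
rec = (F.conv_transpose1d(cA, lo, stride=2, padding=pad) +
       F.conv_transpose1d(cD, hi, stride=2, padding=pad))
print(rec)                                     # matches x up to float error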

This might improve the results for the other WT bases in Table 8. However, it might create quite a mess with the shapes of the coefficients and result in unexpected behaviors and computational costs. I didn't expect it to improve the results much, so I didn't bother to implement and test it. If you are willing to try it out, please let me know the results :)