v0lta / PyTorch-Wavelet-Toolbox

Differentiable fast wavelet transforms in PyTorch with GPU support.
https://pytorch-wavelet-toolbox.readthedocs.io
European Union Public License 1.2

Accuracy of 3D transform #83

Closed Peiyannn closed 6 months ago

Peiyannn commented 6 months ago

Hello! I appreciate your work and am very interested in it.

I have a question: why does the data reconstructed with ptwt.wavedec3 and ptwt.waverec3 differ from the data reconstructed with pywt.wavedec and pywt.waverec? The reconstruction error with ptwt is considerably larger than with pywt. I tried changing the wavelet and found that some wavelets improve the situation a lot. How can I mitigate this issue?

cthoyt commented 6 months ago

@Peiyannn, the best place to start is to share some code that reproduces your results.

Peiyannn commented 6 months ago

This is the code:

import numpy as np
import ptwt
import pywt
import torch

X_sub = torch.randn(3, 8, 16, 16)
wavelet = pywt.Wavelet("bior1.3")
mode = "zero"
coef = ptwt.wavedec3(X_sub, wavelet, mode=mode, level=1)
coef2 = pywt.wavedec(X_sub.numpy(), wavelet, mode=mode, level=1)
Y_ = ptwt.waverec3(coef, wavelet)
Y_2 = pywt.waverec(coef2, wavelet)
rec_error_Y_ = torch.norm(
    X_sub - Y_[:, : X_sub.shape[-3], : X_sub.shape[-2], : X_sub.shape[-1]]
) / torch.norm(X_sub)
rec_error_Y_2 = np.sqrt(
    np.sum((X_sub.cpu().numpy() - Y_2[:, :, : X_sub.shape[-2], : X_sub.shape[-1]]) ** 2)
) / np.sqrt(np.sum((X_sub.cpu().numpy()) ** 2))
print(rec_error_Y_, rec_error_Y_2)

v0lta commented 6 months ago

PyTorch and NumPy use different default precisions. I think the NumPy default is float64, while PyTorch uses 32-bit floats (float32).
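
To give a feeling for the size of the effect, here is a small sketch. It first prints the default floating-point types of both libraries and then measures the round-trip error of a hand-written one-level Haar transform in float32 and in float64. The helper `haar_round_trip_error` is made up for this illustration and is not how ptwt or pywt implement their transforms; it only assumes an even-length last axis.

```python
import math

import numpy as np
import torch

# Default floating-point types: float32 in PyTorch, float64 in NumPy.
print(torch.randn(4).dtype)      # torch.float32
print(np.random.randn(4).dtype)  # float64


def haar_round_trip_error(x: torch.Tensor) -> float:
    """Relative error of a hand-written one-level Haar analysis/synthesis pair.

    Hypothetical helper for illustration only; it works along the last axis.
    """
    s = math.sqrt(2.0)
    even, odd = x[..., 0::2], x[..., 1::2]
    approx, detail = (even + odd) / s, (even - odd) / s  # analysis step
    rec = torch.empty_like(x)
    rec[..., 0::2] = (approx + detail) / s  # synthesis step
    rec[..., 1::2] = (approx - detail) / s
    return (torch.norm(x - rec) / torch.norm(x)).item()


torch.manual_seed(0)
x64 = torch.randn(3, 8, 16, 16, dtype=torch.float64)
err32 = haar_round_trip_error(x64.to(torch.float32))  # same data, single precision
err64 = haar_round_trip_error(x64)                    # double precision
print(f"float32: {err32:.1e}, float64: {err64:.1e}")
```

On the same data, the float32 error should come out several orders of magnitude larger than the float64 one, simply because of round-off.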

v0lta commented 6 months ago

In response, you can explicitly ask PyTorch for 64-bit floats; see e.g. https://github.com/v0lta/PyTorch-Wavelet-Toolbox/blob/1b75acdbe7c9cbfb526fd1a7c19ad260b935c6ec/tests/test_convolution_fwt_3.py#L107.
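
For reference, a short sketch of the usual ways to get a float64 tensor in PyTorch. This is not copied from the linked test; the variable names are only for illustration.

```python
import torch

x32 = torch.randn(3, 8, 16, 16)  # float32 by default
x64_a = torch.randn(3, 8, 16, 16, dtype=torch.float64)  # request float64 at creation
x64_b = x32.to(torch.float64)  # convert an existing tensor
x64_c = x32.double()  # shorthand for the same conversion

# .numpy() keeps the tensor's dtype, so the NumPy array follows the tensor.
print(x32.numpy().dtype, x64_a.numpy().dtype)  # float32 float64
```

Note that converting an existing float32 tensor does not bring back digits that were already rounded away, so for a fair comparison it is probably best to create the data in float64 from the start.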

Peiyannn commented 6 months ago

Got it. Thanks!

v0lta commented 6 months ago

@Peiyannn, based on your code, I wrote a minimal working example:

import numpy as np
import ptwt
import pywt
import torch

# Request float64 explicitly; torch.randn defaults to float32.
X_sub = torch.randn(3, 8, 16, 16, dtype=torch.float64)
wavelet = pywt.Wavelet("bior1.3")
mode = "zero"

# Decompose and reconstruct with both libraries.
# Note: pywt.wavedec is a 1D transform along the last axis.
coef = ptwt.wavedec3(X_sub, wavelet, mode=mode, level=1)
coef2 = pywt.wavedec(X_sub.numpy(), wavelet, mode=mode, level=1)
Y_ = ptwt.waverec3(coef, wavelet)
Y_2 = pywt.waverec(coef2, wavelet)

# Relative reconstruction errors, cropped to the input shape.
rec_error_Y_ = torch.norm(
    X_sub - Y_[:, : X_sub.shape[-3], : X_sub.shape[-2], : X_sub.shape[-1]]
) / torch.norm(X_sub)
rec_error_Y_2 = np.sqrt(
    np.sum((X_sub.cpu().numpy() - Y_2[:, :, : X_sub.shape[-2], : X_sub.shape[-1]]) ** 2)
) / np.sqrt(np.sum((X_sub.cpu().numpy()) ** 2))
print(rec_error_Y_, rec_error_Y_2)
print(f"Numerically equal: {np.allclose(Y_.numpy(), Y_2)}")

I hope it helps. I am closing this now. Feel free to reopen if you have any more questions.