Peiyannn closed this issue 6 months ago.
@Peiyannn, the best place to start is to share some code that reproduces your results.
This is the code:
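import numpy as np
import pywt
import ptwt
import torch

X_sub = torch.randn(3, 8, 16, 16)  # torch.randn returns float32 by default
wavelet = pywt.Wavelet("bior1.3")
mode = "zero"
# 3D transform over the last three axes with ptwt
coef = ptwt.wavedec3(X_sub, wavelet, mode=mode, level=1)
# pywt.wavedec transforms only the last axis (axis=-1 by default)
coef2 = pywt.wavedec(X_sub.numpy(), wavelet, mode=mode, level=1)
Y_ = ptwt.waverec3(coef, wavelet)
Y_2 = pywt.waverec(coef2, wavelet, mode=mode)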
rec_error_Y_ = torch.norm(
X_sub - Y_[:, : X_sub.shape[-3], : X_sub.shape[-2], : X_sub.shape[-1]]
) / torch.norm(X_sub)
rec_error_Y_2 = np.sqrt(
np.sum((X_sub.cpu().numpy() - Y_2[:, :, : X_sub.shape[-2], : X_sub.shape[-1]]) ** 2)
) / np.sqrt(np.sum((X_sub.cpu().numpy()) ** 2))
print(rec_error_Y_, rec_error_Y_2)
PyTorch and NumPy use different default precisions: NumPy defaults to float64, while PyTorch uses float32.
To compensate, you can explicitly ask PyTorch for 64-bit floats; see, e.g., https://github.com/v0lta/PyTorch-Wavelet-Toolbox/blob/1b75acdbe7c9cbfb526fd1a7c19ad260b935c6ec/tests/test_convolution_fwt_3.py#L107
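To see how much the dtype alone can matter, here is a small NumPy-only sketch. It assumes a hand-rolled, single-level orthogonal Haar transform along the last axis as a stand-in for the ptwt/pywt transforms (an illustration only, not either library's implementation) and runs the identical analysis/synthesis round trip once in float64 and once in float32.

```python
import numpy as np


def haar_roundtrip(x):
    """One level of Haar analysis + synthesis along the last axis.

    Illustrative stand-in for a wavelet library; all arithmetic stays in
    x.dtype, so float32 input is processed entirely in float32.
    """
    s = np.sqrt(np.asarray(2.0, dtype=x.dtype))
    even, odd = x[..., 0::2], x[..., 1::2]
    approx = (even + odd) / s  # low-pass coefficients
    detail = (even - odd) / s  # high-pass coefficients
    rec = np.empty_like(x)
    rec[..., 0::2] = (approx + detail) / s
    rec[..., 1::2] = (approx - detail) / s
    return rec


def relative_error(x, y):
    # Evaluate the metric in float64 so it adds no float32 rounding itself.
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    return np.linalg.norm(x - y) / np.linalg.norm(x)


rng = np.random.default_rng(0)
x64 = rng.standard_normal((3, 8, 16, 16))  # float64, NumPy's default
x32 = x64.astype(np.float32)               # float32, torch.randn's default

err64 = relative_error(x64, haar_roundtrip(x64))
err32 = relative_error(x32, haar_roundtrip(x32))
print(f"float64 round-trip error: {err64:.1e}")
print(f"float32 round-trip error: {err32:.1e}")
```

The transform is exactly invertible in both cases, so any remaining error comes from floating-point rounding; the float32 run should land many orders of magnitude above the float64 run.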
Got it. Thanks!
@Peiyannn, based on your code, I wrote a minimal working example:
import torch
import numpy as np
import pywt
import ptwt
X_sub = torch.randn(3, 8, 16, 16).type(torch.float64)  # request 64-bit floats explicitly
wavelet = pywt.Wavelet("bior1.3")
mode = "zero"
coef = ptwt.wavedec3(X_sub, wavelet, mode=mode, level=1)
coef2 = pywt.wavedec(X_sub.numpy(), wavelet, mode=mode, level=1)
Y_ = ptwt.waverec3(coef, wavelet)
Y_2 = pywt.waverec(coef2, wavelet, mode=mode)
rec_error_Y_ = torch.norm(
X_sub - Y_[:, : X_sub.shape[-3], : X_sub.shape[-2], : X_sub.shape[-1]]
) / torch.norm(X_sub)
rec_error_Y_2 = np.sqrt(
np.sum((X_sub.cpu().numpy() - Y_2[:, :, : X_sub.shape[-2], : X_sub.shape[-1]]) ** 2)
) / np.sqrt(np.sum((X_sub.cpu().numpy()) ** 2))
print(rec_error_Y_, rec_error_Y_2)
print(f"Numerically equal: {np.allclose(Y_.numpy(), Y_2)}")
I hope it helps. I am closing this now. Feel free to reopen if you have any more questions.
Hello! I appreciate your work and am very interested in it.
I have a question: why does the data reconstructed with ptwt.wavedec3 and ptwt.waverec3 differ from the data reconstructed with pywt.wavedec and pywt.waverec? The reconstruction error with ptwt is considerably larger than with pywt. I tried changing the wavelet and found that with some wavelets the situation is much better. How can I mitigate this issue?
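A rough yardstick for judging a float32 reconstruction error: simply storing the data in float32 already rounds every element, which costs a relative error of up to half of float32's machine epsilon (about 6e-8). A float32 pipeline should therefore not be expected to match float64-level errors. The sketch below assumes NumPy only; `relative_error` is a hypothetical helper mirroring the norm-ratio formula used in the snippets in this thread, evaluated in float64.

```python
import numpy as np


def relative_error(x, y):
    """Relative L2 error ||x - y|| / ||x||, computed in float64."""
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    return np.linalg.norm(x - y) / np.linalg.norm(x)


rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8, 16, 16))  # float64 input

# Rounding each element to the nearest float32 bounds the relative error
# by the unit roundoff, i.e. half of float32's machine epsilon.
floor32 = relative_error(x, x.astype(np.float32))
unit_roundoff32 = np.finfo(np.float32).eps / 2
print(f"error from float32 storage alone: {floor32:.1e}")
print(f"float32 unit roundoff:            {unit_roundoff32:.1e}")
```

If a float32 wavelet round trip reports errors of roughly this size, the transform is likely behaving as expected and switching to float64 is the remedy.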