fbcotter / pytorch_wavelets

PyTorch implementation of 2D Discrete Wavelet Transforms (DWT) and Dual-Tree Complex Wavelet Transforms (DTCWT), and a DTCWT-based ScatterNet

Precision problem of DWTForward & DWTInverse #23

Closed: zhouhuanxiang closed this issue 3 years ago

zhouhuanxiang commented 3 years ago

I find that there is an obvious difference between the input to pytorch_wavelets.DWTForward and the output of pytorch_wavelets.DWTInverse.

import torch
import pytorch_wavelets

DWTForward = pytorch_wavelets.DWTForward()
DWTInverse = pytorch_wavelets.DWTInverse()

x = torch.abs(torch.randn(17, 7, 32, 32))
yl, yh = DWTForward(x)
y = DWTInverse((yl, yh))
print(torch.sum(torch.abs(y - x)))

The summed absolute difference is about 0.0080.

fbcotter commented 3 years ago

Yes, with single-precision floating point there is a limit to how accurate the round trip can be. I wouldn't call this an obvious difference, however: you are summing the absolute error over all 17 × 7 × 32 × 32 = 121,856 elements. If I follow your code with a slight modification:

import torch
import pytorch_wavelets

DWTForward = pytorch_wavelets.DWTForward()
DWTInverse = pytorch_wavelets.DWTInverse()

x = torch.abs(torch.randn(17, 7, 32, 32))
yl, yh = DWTForward(x)
y = DWTInverse((yl, yh))
e = y - x  # per-element reconstruction error
print(e.std())

then the standard deviation of e is around 1e-7, giving an SNR of 137 dB, which by any measure is a very small error. If you do need more precision, however, you can switch to double precision by calling torch.set_default_dtype(torch.float64) before creating the DWT objects. This reduces the standard deviation of e to about 1e-16.
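To check that this round-trip error comes from floating-point arithmetic rather than from pytorch_wavelets itself, here is a minimal sketch in NumPy (so it runs without PyTorch) that swaps in a hand-written single-level orthonormal 2D Haar transform for the library. Haar is the db1 wavelet, which I believe is the default for DWTForward/DWTInverse, but this is not the library's convolution-based code path, so the exact numbers will differ from those in the thread. The helpers haar_dwt2, haar_idwt2 and roundtrip_stats are hypothetical names for illustration only. The reply does not say how the 137 dB figure was computed; the sketch assumes SNR = 10·log10(Σx² / Σe²). For scale, the 0.0080 total from the question is about 0.0080 / 121,856 ≈ 6.6 × 10⁻⁸ per element, the same order as float32 machine epsilon (2⁻²³ ≈ 1.19 × 10⁻⁷).

```python
import numpy as np


def haar_dwt2(x):
    """One-level orthonormal 2D Haar analysis over the last two axes.

    Hypothetical helper for illustration; not the pytorch_wavelets code path.
    """
    s = x.dtype.type(np.sqrt(0.5))  # 1/sqrt(2) in the input's precision
    # Filter along the last axis: pairwise sums (lowpass) and differences (highpass).
    lo = (x[..., 0::2] + x[..., 1::2]) * s
    hi = (x[..., 0::2] - x[..., 1::2]) * s
    # Filter both results along the second-to-last axis.
    ll = (lo[..., 0::2, :] + lo[..., 1::2, :]) * s
    lh = (lo[..., 0::2, :] - lo[..., 1::2, :]) * s
    hl = (hi[..., 0::2, :] + hi[..., 1::2, :]) * s
    hh = (hi[..., 0::2, :] - hi[..., 1::2, :]) * s
    return ll, (lh, hl, hh)


def haar_idwt2(ll, highs):
    """Inverse of haar_dwt2: undo the column step, then the row step."""
    lh, hl, hh = highs
    s = ll.dtype.type(np.sqrt(0.5))
    rows, cols = ll.shape[-2], ll.shape[-1]
    lo = np.empty(ll.shape[:-2] + (2 * rows, cols), dtype=ll.dtype)
    hi = np.empty_like(lo)
    lo[..., 0::2, :] = (ll + lh) * s
    lo[..., 1::2, :] = (ll - lh) * s
    hi[..., 0::2, :] = (hl + hh) * s
    hi[..., 1::2, :] = (hl - hh) * s
    x = np.empty(lo.shape[:-1] + (2 * cols,), dtype=ll.dtype)
    x[..., 0::2] = (lo + hi) * s
    x[..., 1::2] = (lo - hi) * s
    return x


def roundtrip_stats(dtype, seed=0):
    """Forward + inverse transform; return (sum|e|, std(e), SNR in dB)."""
    rng = np.random.default_rng(seed)
    # Same shape and distribution as the snippet in the issue: |N(0, 1)|.
    x = np.abs(rng.standard_normal((17, 7, 32, 32))).astype(dtype)
    ll, highs = haar_dwt2(x)
    e = haar_idwt2(ll, highs) - x
    # Assumed SNR definition: signal energy over error energy, computed in float64.
    signal = np.sum(x.astype(np.float64) ** 2)
    noise = np.sum(e.astype(np.float64) ** 2)
    return float(np.abs(e).sum()), float(e.std()), float(10 * np.log10(signal / noise))


for dtype in (np.float32, np.float64):
    total, std, snr = roundtrip_stats(dtype)
    print(f"{np.dtype(dtype).name}: sum|e| = {total:.3g}, std(e) = {std:.3g}, SNR = {snr:.1f} dB")
```

Running it should show a float32 error many orders of magnitude larger than the float64 one, with both SNRs still very high. Changing the dtype here plays the same role that torch.set_default_dtype(torch.float64) plays for the DWT objects.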