PyWavelets / pywt

PyWavelets - Wavelet Transforms in Python
http://pywavelets.readthedocs.org
MIT License

Why does the DWT give values that should be 0 but are in fact around 1e-16? #755

Open Ollie-spoon opened 6 days ago

Ollie-spoon commented 6 days ago

I've been using the DWT and it seems like a great tool, but I keep running into floating-point error at around 1e-16 rather than near the smallest value that np.float64 can represent, ~1e-300. These values also seem to make the transform distort the signal somewhat, but more than that, they make the tool a lot harder to work with.

What I want to do is use the dwt for denoising, by taking the regions of the detail signals that go to zero with a noiseless signal, and setting these regions to zero on a noisy signal. My logic is that the linear nature of the DWT should ensure that the signal is not modified, but I'm finding that the DWT alone is modifying the signal.

signal: exp = 0.2*exp(-t/50.850886977157955) + 0.5*exp(-t/139.33995606900018) + 0.3*exp(-t/235.1443946374417)

Plot of my triexponential decay signal decomposed: dwt_detail_all

Plot of the reconstruction error (reconstructed signal minus original signal): dwt_reconstruction_error

If this is an error that I have made then fair enough, but it has cropped up in every single DWT plot that I have produced, so I'm leaning away from this conclusion. Assuming that it's not user error, is there a way I can modify the dwt code to ensure that the zero points are lower than ~1e-16?

Code included for completeness:

import numpy as np
import pywt
import matplotlib.pyplot as plt

# initialize exponential
tau1 = 40+np.random.normal(scale=8)
tau2 = 140+np.random.normal(scale=20)
tau3 = 225+np.random.normal(scale=20)
A1 = 0.5
A2 = 1.25
A3 = 0.75
dtype = np.float64
sample_rate = 10
t_cutoff = 1275
t = np.linspace(0, t_cutoff, int(t_cutoff*sample_rate), dtype=dtype)
V = A1*np.exp(-t/tau1) + A2*np.exp(-t/tau2) + A3*np.exp(-t/tau3)

# log10 function that can handle zero values
def safe_log10(arr):
    def replace_zeros(arr):
        zero_indices = np.where(arr == 0)[0]
        while zero_indices.size > 0:
            for index in zero_indices:
                if index == 0:
                    arr[index] = arr[index + 1]  # Replace with next value for first element
                elif index == len(arr) - 1:
                    arr[index] = arr[index - 1]  # Replace with previous value for last element
                else:
                    left_val = arr[index - 1]
                    right_val = arr[index + 1]
                    if left_val == 0:
                        arr[index] = right_val  # Replace with right value if left is also zero
                    elif right_val == 0:
                        arr[index] = left_val  # Replace with left value if right is also zero
                    else:
                        arr[index] = (left_val + right_val) / 2  # Average of left and right
            zero_indices = np.where(arr == 0)[0]
        return arr

    # Replace zeros with appropriate values
    arr = replace_zeros(arr)
    # Compute log10 safely
    log_arr = np.log10(arr)
    return log_arr

# iterate through wavelet list
wavelet_list = ['sym11', 'sym12', 'sym13', 'sym14', 'sym15', 'db12', 'db13', 'db14', 'db15', 'db16']
V_reconstructed = np.copy(V)
for wavelet in wavelet_list:
    dwt_data = pywt.wavedec(V_reconstructed, wavelet, mode='symmetric', level=None, axis=-1)

    approximation_coefficients = dwt_data[0]

    # dwt_data[1:] contains the detail coefficients for each level
    detail_coefficients = dwt_data[1:]

    plt.figure(figsize=(12,6))

    # Plot Approximation Coefficients
    plt.subplot(2, 1, 1)
    plt.plot(safe_log10(approximation_coefficients), label='Approximation Coefficients (A)')
    plt.title(f'Approximation Coefficients (A) for {wavelet}')
    plt.legend()
    plt.grid(True)

    # Plot Detail Coefficients
    plt.subplot(2, 1, 2)
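    # wavedec returns [cA_n, cD_n, ..., cD_1], so the first detail array is the coarsest level
    n_levels = len(detail_coefficients)
    for i, detail in enumerate(detail_coefficients):
        t_ = np.linspace(0, len(detail_coefficients[-1]), len(detail))
        plt.plot(t_, safe_log10(np.abs(detail)), label=f'Detail Coefficients (D{n_levels - i})')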
    plt.title('Detail Coefficients (D1, D2, ..., Dn)')
    plt.legend()
    plt.grid(True)

    V_reconstructed = pywt.waverec(dwt_data, wavelet, mode='symmetric')

plt.figure(figsize=(14, 6))
plt.plot(t, V_reconstructed-V, color="orange")
plt.show()

Ollie-spoon commented 6 days ago

As a follow-up to this, I had a go at trying to modify the code on the pywt repo to allow long double values which I believe correspond to np.float80 which would hopefully give an increased precision and overall lower error. However, this is definitely above my pay grade; I've never used Cython before.

rgommers commented 5 days ago

Hi @Ollie-spoon, thanks for the question.

Assuming that it's not user error, is there a way I can modify the dwt code to ensure that the zero points are lower than ~1e-16?

I didn't have the time to check your code in detail, but a relative error of 1e-16 is always expected for float64 operations. However, absolute errors can and should be much smaller when the amplitude of the input arrays is well below O(1).
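
To illustrate the difference between the two numbers in the question, here is a minimal sketch using only NumPy. It assumes nothing about pywt: it just prints the float64 machine epsilon (about 2.2e-16, which bounds relative rounding error) next to the smallest positive normal value (about 2.2e-308, which only bounds how small a number can be stored), and shows that the spacing between adjacent floats scales with magnitude.

```python
import numpy as np

info = np.finfo(np.float64)
eps = info.eps    # gap between 1.0 and the next float64, about 2.2e-16
tiny = info.tiny  # smallest positive normal float64, about 2.2e-308

print(f"eps  = {eps:.3e}")
print(f"tiny = {tiny:.3e}")

# Rounding error is relative: the gap between neighbouring floats
# shrinks with the magnitude of the value, but gap / value stays near eps.
for x in (1.0, 1e-8, 1e-100):
    gap = np.spacing(x)
    print(f"x = {x:g}: gap = {gap:.3e}, gap / x = {gap / x:.3e}")

# A sum that is exactly zero on paper leaves a residue below eps
# when the operands are of order 1.
residual = (0.1 + 0.2) - 0.3
print(f"(0.1 + 0.2) - 0.3 = {residual:.3e}")
```

Since the signal in this thread has amplitude of order 1, residues around 1e-16 in coefficients that would be zero in exact arithmetic are consistent with this picture.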

I had a go at trying to modify the code on the pywt repo to allow long double values which I believe correspond to np.float80 which would hopefully give an increased precision

long double is a huge can of worms, and on Windows and macOS arm64 it doesn't increase precision at all. So we almost certainly don't want to deal with that.
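
One way to check what long double means on a given machine is to compare NumPy's reported limits for np.longdouble and np.float64. This is only a sketch; the variable names are illustrative and the printed values depend on the platform and compiler.

```python
import numpy as np

f64 = np.finfo(np.float64)
ld = np.finfo(np.longdouble)

# float() keeps the formatting independent of the longdouble scalar type
print(f"float64:    {f64.bits} bits stored, {f64.nmant} mantissa bits, eps = {float(f64.eps):.3e}")
print(f"longdouble: {ld.bits} bits stored, {ld.nmant} mantissa bits, eps = {float(ld.eps):.3e}")

# True only where long double really carries more precision than double
has_extra_precision = float(ld.eps) < float(f64.eps)
print("long double adds precision here:", has_extra_precision)
```

If the two lines report the same eps, long double is just an alias for double on that platform and would not reduce the error discussed above.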

What I want to do is use the dwt for denoising, by taking the regions of the detail signals that go to zero with a noiseless signal, and setting these regions to zero on a noisy signal.

Normally if the background floating-point noise is relevant, you use a thresholding function to zero out the elements below some threshold.
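
As a rough sketch of that suggestion, the snippet below uses pywt.threshold in hard mode to zero detail coefficients that fall below a tolerance. The time constants, the wavelet choice, and the tolerance of 1e-12 times the largest coefficient are assumptions made for illustration, not values recommended in this thread.

```python
import numpy as np
import pywt

# Noiseless tri-exponential decay, similar in shape to the signal in the question
# (fixed, illustrative time constants instead of the random draws)
t = np.linspace(0, 1275, 12750)
clean = 0.5 * np.exp(-t / 40) + 1.25 * np.exp(-t / 140) + 0.75 * np.exp(-t / 225)

wavelet = "sym11"
coeffs = pywt.wavedec(clean, wavelet, mode="symmetric")

# Hypothetical tolerance: treat anything below 1e-12 of the largest
# coefficient magnitude as floating-point round-off
largest = max(np.max(np.abs(c)) for c in coeffs)
tol = 1e-12 * largest

# Hard thresholding keeps coefficients with |c| >= tol unchanged and
# replaces the rest with 0; the approximation coefficients are left alone
cleaned = [coeffs[0]] + [pywt.threshold(c, tol, mode="hard") for c in coeffs[1:]]

zeroed = sum(int(np.sum(c == 0)) for c in cleaned[1:])
total = sum(c.size for c in cleaned[1:])
print(f"zeroed {zeroed} of {total} detail coefficients")

# Slice in case the reconstruction comes back one sample longer than the input
reconstructed = pywt.waverec(cleaned, wavelet, mode="symmetric")[: clean.size]
max_err = np.max(np.abs(reconstructed - clean))
print(f"max reconstruction error: {max_err:.3e}")
```

The positions zeroed on the noiseless signal could then serve as a mask for the noisy signal's coefficients, which is the approach described in the question. The tolerance would need tuning for real data.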