PyWavelets / pywt

PyWavelets - Wavelet Transforms in Python
http://pywavelets.readthedocs.org
MIT License
1.97k stars 460 forks source link

Why Gabor wavelet does not implemented in CWT? #637

Open kaz0120 opened 2 years ago

kaz0120 commented 2 years ago

Hi there, at first, sorry for my poor English. I want to apply CWT on my signal data by Python, i found this library; PyWavelets. I used Morlet mother-wavelet, and get great result. Next, I found the Improved Gavor wavelet. Improved Gabor wavelet (IGW) is great one to understand the relationships between time ans frequency domains, so just I thought I wanna use this wavelet. I wrote the scratch code for IGW successfully, but it's toooooooo heavy to calculate the coefficients. So I refered pywt.cwt and rewrite the code as below; but I can't make sure how to implemet the calculations because IGW does not use the scale parameter. Why pywt.cwt compute the integrals? Is it faster method? How can I see its argorithm? Please help me...

import math
import numbers
from typing import List, Optional, Tuple, Union

import numpy as np
import pywt
from pywt import ContinuousWavelet

class ExpandingWavelet(object):
    import scipy.fft
    fftmodule = scipy.fft
    next_fast_len = fftmodule.next_fast_len

    def __init__(self, name: str):
        pass

    def __new__(cls, name: str, *args, **kwargs):
        if "gabor" in name:
            tmp = type("ExpandingWavelet", (object,), 
                       dict(wavefun      = ExpandingWavelet.gabor_wavefun,
                            _gabor_sigma = name.lower().rsplit("gabor")[-1],
                            complex_cwt  = True,
                            upper_bound  = 8.0,
                            lower_bound  = -8.0
                           )
                      )
            if len(tmp._gabor_sigma) == 0:
                tmp._gabor_sigma = 0.8909
            return tmp.__new__(cls)
        else:
            return ContinuousWavelet(name)

    def gabor_wavefun(self, 
        level: int = 8, 
        length: Optional[int] = None,
        f: numbers.Number = 1
    ) -> Tuple[np.ndarray, np.ndarray]:
        """ improved Gabor wavelet: Publisher: Ji Z., Yan S., Bao J., "An improved Gabor wavelet and 
        its complete transforms," 2015 IEEE ICSPCC, 2015, DOI: 10.1109/ICSPCC.2015.7338925
        """
        assert isinstance(level, int), "``level`` must be integer, not {}".format(type(level))
        maxlen = 2**level
        if length is None:
            length = maxlen
        x = np.linspace(self.lower_bound, self.upper_bound, 2**level)
        _coef = 1/(self._gabor_sigma * np.sqrt(2 * np.pi))
        psi = np.array([_coef * np.exp( -(f*t)**2/(2 * self._gabor_sigma**2) + 2 * np.pi * 1j * (f*t) )
                        for t in x], dtype=complex if self.complex_cwt else float)
        assert len(x) == len(psi)
        if maxlen == length:
            return (psi, x)
        else:
            _idxs = np.linspace(0, maxlen, length).astype(int)
            return (psi[_idxs], x[_idxs])

    def integrate(self, precision: int = 8) -> Tuple[np.ndarray, np.ndarray]:
        """ call ``pywt.integrate_wavelet``
        """
        return pywt.integrate_wavelet(self.wavefun(level=precision), precision=precision)

    def cwt(
        self, 
        data: np.ndarray, 
        freqs: Union[numbers.Number, List[numbers.Number], np.ndarray],
        precision: int = 10,
        sampling_period: int = 1,
        method: str = "conv", 
        axis: int = -1
    ) -> Tuple[np.ndarray, ...]:
        """ Continuous Wavelet Transform
        """
        dtype_data = type(data)
        data = np.asarray(data, dtype=dtype_data) 
        dtype_cplx = np.result_type(dtype_data, np.complex64)
        if np.isscalar(freqs):
            freqs = np.array([freqs])

        assert data.ndim == 1, exec(
            f'raise ValueError("Invalid shape for ``data``. 1-D data is only accepted, not {data.ndim=}")')
        assert freqs.min() > 0, exec(
            f'raise ValueError("Invalid value for ``freqs``. ``freqs`` must be greater than 0 (freqs > 0).")')

        assert isinstance(axis, int), exec(
            f'raise ValueError("Invalid value for ``axis``. Integer is required, not {axis=}")')
        assert isinstance(precision, int), exec(
            f'raise ValueError("Invalid value for ``precision``. Integer is required, not {precision=}")')
        assert method in ["fft", "conv"], exec(
            f'raise ValueError("Invalid value for ``method``. `fft` or `conv` are only supported, not {method=}")')

        dtype_out = dtype_cplx if self.complex_cwt else dtype_data
        out = np.empty((np.size(freqs),) + data.shape, dtype=dtype_out)

        (integ_psi, x) = self.integrate(precision=precision)
        integ_psi = np.conj(integ_psi) if self.complex_cwt else integ_psi

        ## convert integ_psi, x to the same preciion as the data
        dtype_psi = dtype_cplx if integ_psi.dtype.kind == "c" else dtype_data
        integ_psi = np.asarray(integ_psi, dtype=dtype_psi)
        x = np.asarray(x, dtype=data.real.dtype)
        x_scale = x[-1] - x[0]
        x_step = x[1] - x[0]

        for i, freq in enumerate(freqs):
            ~~~~ I CAN'T IMPLEMET HERE ~~~~

            if out.dtype.kind != "c":
                coef = coef.real
            d = (coef.shape[-1] - data.shape[-1]) / 2.
            if d > 0:
                coef = coef[..., math.floor(d):-math.ceil(d)]
            elif d < 0:
                raise ValueError(f"Selected freq of {freq} too small.")

            out[i, ...] = coef
        # end for

        frequencies = freqs  ## Improved Gabor makes corresponding between time-freq domain directly; not need scale domain
        if np.isscalar(frequencies):
            frequencies = np.array([frequencies])

        return (out, frequencies)
kaz0120 commented 2 years ago

When using Numba, the scratch code is below; could you fix the code to be faaaaaaaaaaaaaaster? Improved Gabor Wavelet ain't required any scale transformations, this is because it's easy to utilize and understand. But, without Numba, took toooooooooooo long times........


from numba import njit, prange, objmode, complex128
import numpy as np

@njit("c16(f8, f8)", fastmath=False)
def jit_gabor_t(
    t: float, 
    sigma: float = 0.8909
) -> complex:
    return 1/(sigma * np.sqrt(2 * np.pi)) * np.exp( -t**2/(2 * sigma**2) + 2 * np.pi * 1j * t )

@njit("c16(f8, f8, f8)", fastmath=False)
def jit_gabor_f_tau_t(
    f: float,
    tau: float,
    t: float
) -> complex:
    return abs(f) * jit_gabor_t(t=f * (t - tau), sigma=0.8909)

@njit("f8(f8, f8, f8)", fastmath=False)
def jit_calc_wavelet_window_width(
    f: float, 
    sigma: float, 
    amp: float = 0.005, 
) -> float:
    return 1/f * sigma * np.sqrt(-2 * np.log(amp))

@njit("c16(f8[:], f8[:], f8, f8)", fastmath=False)
def jit_gabor_J(
    sig: np.ndarray,
    time: np.ndarray, 
    f: float, 
    tau: float
) -> complex:
    tmp = time - tau
    for i in prange(len(tmp)):
        tmp[i] = abs(tmp[i])
    tauidx = np.argmin(tmp)
    wavelet_window_width = jit_calc_wavelet_window_width(f=f, sigma=0.8909, amp=0.005)
    _indexes = np.where(  (time[tauidx] - wavelet_window_width <= time) 
                        & (time <= time[tauidx] + wavelet_window_width) )
    idxs = np.arange(len(time))[_indexes]
    min_, max_ = idxs.min(), idxs.max()
    out = 0. + 0.j
    for i in range(min_, max_+1):
        out += sig[i] * jit_gabor_f_tau_t(f=f, tau=tau, t=time[i])
    return out

@njit("c16[:,:](f8[:], f8[:], f8[:])")
def jit_gabor(
    time: np.ndarray, 
    sig: np.ndarray, 
    freqs: np.ndarray
) -> np.ndarray:
    N_t = len(time)
    N_f = len(freqs)
    out = np.zeros((N_f, N_t), dtype=complex128)
    for fi in range(N_f):
        for ti in range(N_t):
            out[fi,ti] = jit_gabor_J(sig=sig, time=time, f=freqs[fi], tau=time[ti])
    return out

Sample:


import time

import matplotlib.pyplot as plt
import numpy

N_t = 200
N_f = 20
dt = 0.01  # sampling intervals
t = np.arange(-1, 1, dt)
sig  = np.cos(2 * np.pi * 7 * t) + np.real(np.exp(-7*(t-0.4)**2)*np.exp(1j*2*np.pi*2*(t-0.4)))
start = time.time()
res = jit_gabor(time=t, sig=sig, freqs=np.arange(1, N_f).astype(float))
fin = time.time() - start
print(fin, "sec")

fig, ax = plt.subplots()
ax.imshow(np.abs(res), aspect='auto', origin='lower')
ax.twinx().plot(sig)
ax.set(xlabel="Time [sec]", ylabel="Frequency [Hz]")
plt.show()

image

grlee77 commented 2 years ago

Hi @kaz0120, unfortunately the original CWT contributor to PyWavelets is no longer active with the project and I am more familiar with the discrete transforms. In general it is mostly @rgommers and myself doing basic maintenance of the library at the moment, but neither of us has the bandwidth to develop new features.

That said, I have looked at the code briefly in the past and had done some research regarding your question about the integral. There is a summary with links to more info in this comment: https://github.com/PyWavelets/pywt/issues/531#issue-510962058