jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Low pass filters - Butterworth filter #17540

Open simon-bachhuber opened 1 year ago

simon-bachhuber commented 1 year ago

I couldn't find an equivalent of scipy.signal.butter in the jax.scipy module. It would be great to have :)
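
Most of a jax.scipy counterpart of scipy.signal.butter would be coefficient design. Here is a rough sketch of the low-pass case only. It assumes the same route SciPy takes: an analog Butterworth prototype, then a bilinear transform with a pre-warped cutoff. The hypothetical helper butter_lowpass_coeffs computes (b, a) for any order. It uses NumPy because the coefficients are constants computed once, outside any traced function. High-pass, band-pass, analog designs and sos/zpk outputs are not covered.

```python
import numpy as np


def butter_lowpass_coeffs(order: int, wn: float) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: digital Butterworth low-pass (b, a) via the bilinear transform.

    wn is the cutoff normalised to the Nyquist frequency (0 < wn < 1),
    the same convention scipy.signal.butter uses by default.
    """
    fs = 2.0  # work in units where the Nyquist frequency is 1
    # Pre-warp the cutoff so the digital -3 dB point lands exactly at wn.
    warped = 2.0 * fs * np.tan(np.pi * wn / fs)
    # Analog prototype: `order` poles evenly spaced on the left half of the unit circle,
    # scaled to the pre-warped cutoff; the analog filter has no finite zeros.
    k = np.arange(-order + 1, order, 2)
    p_analog = -np.exp(1j * np.pi * k / (2 * order)) * warped
    gain_analog = warped**order
    # Bilinear transform s = 2*fs*(z - 1)/(z + 1) maps each analog pole into the unit disk.
    fs2 = 2.0 * fs
    p_digital = (fs2 + p_analog) / (fs2 - p_analog)
    z_digital = -np.ones(order)  # the zeros at infinity map to z = -1
    gain_digital = gain_analog * np.real(1.0 / np.prod(fs2 - p_analog))
    # Expand zeros/poles into polynomial coefficients in powers of z^-1.
    b = gain_digital * np.poly(z_digital)
    a = np.real(np.poly(p_digital))
    return b, a
```

With the defaults used below (f_sampling = 100, f_cutoff = 15), butter_lowpass_coeffs(2, 15 / 50) should reproduce b0, b1, b2 of the 2nd-order snippet, and its a[1], a[2] are the snippet's a1, a2 with the sign flipped. Applying the coefficients still needs a filtering loop such as jax.lax.scan.
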

simon-bachhuber commented 1 year ago

A simple 2nd-order filter.

import jax
import jax.numpy as jnp

def second_order_butterworth(signal: jax.Array, f_sampling: float = 100.0, f_cutoff: float = 15.0, method: str = "forward_backward") -> jax.Array:
    """2nd-order Butterworth low-pass filter applied along axis 0.

    Coefficients: https://stackoverflow.com/questions/20924868/calculate-coefficients-of-2nd-order-butterworth-low-pass-filter
    `method` is a Python string; under jax.jit mark it static (static_argnames="method").
    """
    if method == "forward_backward":
        # Zero-phase filtering: filter forward, then filter the result backward.
        signal = second_order_butterworth(signal, f_sampling, f_cutoff, "forward")
        return second_order_butterworth(signal, f_sampling, f_cutoff, "backward")
    elif method == "forward":
        pass
    elif method == "backward":
        signal = jnp.flip(signal, axis=0)
    else:
        raise NotImplementedError(f"unknown method: {method!r}")

    # Bilinear-transform coefficients with a pre-warped cutoff frequency.
    ff = f_cutoff / f_sampling
    ita = 1.0 / jnp.tan(jnp.pi * ff)
    q = jnp.sqrt(2.0)
    b0 = 1.0 / (1.0 + q*ita + ita**2)
    b1 = 2*b0
    b2 = b0
    a1 = 2.0 * (ita**2 - 1.0) * b0
    a2 = -(1.0 - q*ita + ita**2) * b0

    def f(carry, x_i):
        # carry = (x[i-1], x[i-2], y[i-1], y[i-2])
        x_im1, x_im2, y_im1, y_im2 = carry
        y_i = b0 * x_i + b1 * x_im1 + b2 * x_im2 + a1 * y_im1 + a2 * y_im2
        return (x_i, x_im1, y_i, y_im1), y_i

    # Seed both the input and the output history with the first two samples.
    init = (signal[1], signal[0]) * 2
    signal = jax.lax.scan(f, init, signal[2:])[1]
    # The scan yields len - 2 outputs; repeat the first one twice to restore the input length.
    signal = jnp.concatenate((signal[0:1],) * 2 + (signal,))

    if method == "backward":
        signal = jnp.flip(signal, axis=0)

    return signal
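
One way to sanity-check the coefficient formulas is to evaluate the magnitude response of the 2nd-order section, H(z) = b0 (1 + 2 z^-1 + z^-2) / (1 - a1 z^-1 - a2 z^-2). A Butterworth low-pass should have gain 1 at DC, 1/sqrt(2) (-3 dB) at f_cutoff, and 0 at Nyquist. The sketch below repackages the same b0..a2 expressions into the conventional (b, a) form. That form needs a sign flip because the snippet adds a1 * y_im1 and a2 * y_im2. The names second_order_coeffs and magnitude_response are hypothetical.

```python
import jax.numpy as jnp


def second_order_coeffs(f_sampling: float, f_cutoff: float):
    """Hypothetical helper: the snippet's coefficients as (b, a), with
    y[n] = b[0] x[n] + b[1] x[n-1] + b[2] x[n-2] - a[1] y[n-1] - a[2] y[n-2]."""
    ita = 1.0 / jnp.tan(jnp.pi * f_cutoff / f_sampling)
    q = jnp.sqrt(2.0)
    b0 = 1.0 / (1.0 + q * ita + ita**2)
    a1 = 2.0 * (ita**2 - 1.0) * b0
    a2 = -(1.0 - q * ita + ita**2) * b0
    b = jnp.stack([b0, 2.0 * b0, b0])
    # The snippet adds a1 * y[n-1] + a2 * y[n-2], so the standard denominator negates them.
    a = jnp.stack([jnp.ones_like(b0), -a1, -a2])
    return b, a


def magnitude_response(b, a, freqs, f_sampling):
    """|H(e^{jw})| at the given frequencies (same units as f_sampling)."""
    z_inv = jnp.exp(-1j * 2.0 * jnp.pi * freqs / f_sampling)
    # polyval with reversed coefficients evaluates sum_k c[k] * z_inv**k.
    num = jnp.polyval(b[::-1], z_inv)
    den = jnp.polyval(a[::-1], z_inv)
    return jnp.abs(num / den)
```

The forward_backward mode runs the filter twice, so the combined magnitude is |H|^2. The effective gain at f_cutoff is therefore 1/2 (about -6 dB), which is also what scipy.signal.filtfilt gives for the same (b, a).
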

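This produces the same output as the following SciPy implementation, apart from different behaviour at the edges. That is most likely due to the more sophisticated padding strategy of scipy.signal.filtfilt: by default it extends the signal at both ends with an odd extension and starts each pass from steady-state initial conditions.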

import scipy.signal

def scipy_butterworth(signal, f_sampling=100, f_cutoff=15, N=2):
    # butter expects the cutoff normalised to the Nyquist frequency (f_sampling / 2).
    b, a = scipy.signal.butter(N, f_cutoff / (f_sampling / 2))
    return scipy.signal.filtfilt(b, a, signal, axis=0)

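To approach filtfilt's edge behaviour in JAX, here is a rough sketch of its default edge handling. It uses an odd extension of padlen = 3 * max(len(a), len(b)) samples at each end, and both passes start from the steady-state initial state that scipy.signal.lfilter_zi computes. The sketch assumes 1-D signals and coefficient vectors of equal length. lfilter_scan, steady_state_zi, odd_ext and filtfilt_sketch are hypothetical names. Close agreement with SciPy's output is an expectation, not something verified here.

```python
import jax
import jax.numpy as jnp
import numpy as np


def lfilter_scan(b, a, x, zi):
    """IIR filter (direct form II transposed) along a 1-D signal via jax.lax.scan.

    Assumes a[0] == 1 and len(a) == len(b); zi is the initial state, shape (len(b) - 1,).
    """
    b = jnp.asarray(b)
    a = jnp.asarray(a)

    def step(z, x_n):
        y_n = b[0] * x_n + z[0]
        # Shift the delay line and add this sample's feed-forward and feedback terms.
        z_next = jnp.concatenate([z[1:], jnp.zeros(1, z.dtype)]) + b[1:] * x_n - a[1:] * y_n
        return z_next, y_n

    _, y = jax.lax.scan(step, zi, x)
    return y


def steady_state_zi(b, a):
    """Initial state for which a unit-step input gives a constant output
    (the same linear system scipy.signal.lfilter_zi solves)."""
    n = len(a)
    companion = np.zeros((n - 1, n - 1))
    companion[0, :] = -a[1:]
    companion[1:, :-1] = np.eye(n - 2)
    rhs = b[1:] - a[1:] * b[0]
    return np.linalg.solve(np.eye(n - 1) - companion.T, rhs)


def odd_ext(x, n):
    """Odd extension by n samples at each end: reflect x through its endpoints."""
    left = 2 * x[0] - x[n:0:-1]
    right = 2 * x[-1] - x[-2:-(n + 2):-1]
    return jnp.concatenate([left, x, right])


def filtfilt_sketch(b, a, x, padlen=None):
    """Zero-phase forward-backward filtering with filtfilt-style edge handling."""
    b = np.asarray(b, dtype=float)
    a = np.asarray(a, dtype=float)
    b, a = b / a[0], a / a[0]
    if padlen is None:
        padlen = 3 * max(len(a), len(b))
    zi = jnp.asarray(steady_state_zi(b, a))
    ext = odd_ext(jnp.asarray(x), padlen)
    # Forward pass, starting from the steady state for the first (padded) sample.
    y = lfilter_scan(b, a, ext, zi * ext[0])
    # Backward pass on the reversed output, starting from its first sample.
    y = lfilter_scan(b, a, y[::-1], zi * y[-1])[::-1]
    return y[padlen:-padlen]
```

b and a are treated as concrete constants, since NumPy solves for the initial state. Under jax.jit, close over them, e.g. jax.jit(lambda x: filtfilt_sketch(b, a, x)). Multi-channel signals could be handled with jax.vmap over the channel axis. The signal must be longer than padlen.
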
Saanidhyavats commented 3 months ago

This sounds like a good feature. Is anyone implementing it? I was thinking of starting on this.

hshi74 commented 1 month ago

Any updates on this thread?