Status: Open — opened by SimiPixel 9 months ago.
Here is a simple second-order Butterworth low-pass filter written in JAX:
import jax
import jax.numpy as jnp
def _butterworth2_coefficients(f_sampling, f_cutoff):
    """Return ``(b0, b1, b2, a1, a2)`` for a 2nd-order Butterworth low-pass.

    Coefficients follow
    https://stackoverflow.com/questions/20924868/calculate-coefficients-of-2nd-order-butterworth-low-pass-filter
    and are meant for the recursion
    ``y[i] = b0*x[i] + b1*x[i-1] + b2*x[i-2] + a1*y[i-1] + a2*y[i-2]``,
    i.e. the feedback coefficients carry the opposite sign of SciPy's ``a``.
    """
    ff = f_cutoff / f_sampling
    # Pre-warped analogue cutoff for the bilinear transform.
    ita = 1.0 / jnp.tan(jnp.pi * ff)
    # sqrt(2) is the damping term of the Butterworth polynomial s^2 + sqrt(2) s + 1.
    q = jnp.sqrt(2.0)
    b0 = 1.0 / (1.0 + q * ita + ita**2)
    b1 = 2 * b0
    b2 = b0
    a1 = 2.0 * (ita**2 - 1.0) * b0
    a2 = -(1.0 - q * ita + ita**2) * b0
    return b0, b1, b2, a1, a2


def _causal_butterworth2(signal, coefficients):
    """Run the 2nd-order recursion forward along axis 0 (same output shape).

    The filter state is seeded with the first two raw samples, used as both
    the past inputs and the past outputs. The two leading output positions,
    which the recursion never produces, hold the first filtered value.
    Requires ``signal.shape[0] >= 3``.
    """
    b0, b1, b2, a1, a2 = coefficients

    def step(carry, x_i):
        x_im1, x_im2, y_im1, y_im2 = carry
        y_i = b0 * x_i + b1 * x_im1 + b2 * x_im2 + a1 * y_im1 + a2 * y_im2
        return (x_i, x_im1, y_i, y_im1), y_i

    init = (signal[1], signal[0], signal[1], signal[0])
    _, filtered = jax.lax.scan(step, init, signal[2:])
    # Keep the output length equal to the input length.
    return jnp.concatenate((filtered[:1], filtered[:1], filtered), axis=0)


def second_order_butterworth(
    signal: jax.Array,
    f_sampling: float = 100,
    f_cutoff: float = 15,
    method: str = "forward_backward",
) -> jax.Array:
    """Low-pass filter ``signal`` along axis 0 with a 2nd-order Butterworth filter.

    Coefficients: https://stackoverflow.com/questions/20924868/calculate-coefficients-of-2nd-order-butterworth-low-pass-filter

    Args:
        signal: Samples along axis 0. Trailing axes, if any, are filtered
            independently. Integer input is promoted to the default float dtype.
        f_sampling: Sampling frequency.
        f_cutoff: Cutoff frequency, in the same units as ``f_sampling``;
            expected to lie in ``(0, f_sampling / 2)``.
        method: ``"forward"`` (causal pass), ``"backward"`` (pass over the
            time-reversed signal), or ``"forward_backward"`` (forward then
            backward, cancelling the phase shift).

    Returns:
        An array with the same shape as ``signal``. Signals with fewer than
        three samples are returned unchanged, since the recursion needs two
        samples of history.

    Raises:
        NotImplementedError: if ``method`` is not one of the values above.
    """
    if method not in ("forward_backward", "forward", "backward"):
        raise NotImplementedError(
            f"unknown method {method!r}; expected 'forward_backward', "
            "'forward' or 'backward'"
        )

    signal = jnp.asarray(signal)
    # lax.scan needs a carry dtype that matches the float filter output.
    if not jnp.issubdtype(signal.dtype, jnp.inexact):
        signal = signal.astype(float)
    # With < 3 samples there is nothing to filter; the old code returned an
    # empty array (length 2) or raised IndexError (length 0/1).
    if signal.shape[0] < 3:
        return signal

    if method == "forward_backward":
        forward = second_order_butterworth(signal, f_sampling, f_cutoff, "forward")
        return second_order_butterworth(forward, f_sampling, f_cutoff, "backward")

    coefficients = _butterworth2_coefficients(f_sampling, f_cutoff)
    if method == "backward":
        reversed_signal = jnp.flip(signal, axis=0)
        return jnp.flip(_causal_butterworth2(reversed_signal, coefficients), axis=0)
    return _causal_butterworth2(signal, coefficients)
It produces the same output as the following SciPy implementation, apart from different behaviour at the edges. This is most likely because `scipy.signal.filtfilt` uses a different (more sophisticated) padding strategy.
def scipy_butterworth(signal, f_sampling=100, f_cutoff=15, N=2):
    """Zero-phase Butterworth low-pass along axis 0, using SciPy as a reference.

    Args:
        signal: Array-like samples along axis 0.
        f_sampling: Sampling frequency.
        f_cutoff: Cutoff frequency, same units as ``f_sampling``.
        N: Filter order.

    Returns:
        The ``scipy.signal.filtfilt`` output (forward-backward filtering),
        same shape as ``signal``.
    """
    # Local import: SciPy is only needed for this reference implementation and
    # is not imported at the top of the file.
    import scipy.signal

    # butter() expects the cutoff normalised to the Nyquist frequency.
    b, a = scipy.signal.butter(N, f_cutoff / (f_sampling / 2))
    return scipy.signal.filtfilt(b, a, signal, axis=0)
I couldn't find an equivalent of `scipy.signal.butter` in the `jax.scipy` module as of now. It would be great to have :)