This is a good issue to pick up as a first step into JAX. I would recommend first reading this introduction to JAX. A few hints:
- Don't recompute the filters on every pass. They can be computed once, offline, and then passed in as an argument, since their memory footprint is relatively small.
- JAX has direct equivalents for almost all NumPy functions: just `import jax.numpy as jnp` and switch `np` to `jnp`.
- Array updates are more formal, because JAX arrays are immutable. For example, rather than `x[y] = z` you must write `x = x.at[y].set(z)`; the same pattern covers `.add`, `.min`, `.max`, etc.
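A minimal sketch of the three hints above, assuming a hypothetical filter bank (`make_filters`, Gaussian windows chosen purely for illustration) and a hypothetical per-pass function `apply_filters`; the real filters and computation for this issue are not shown here.

```python
import numpy as np
import jax
import jax.numpy as jnp


def make_filters(num_filters, length):
    # Hypothetical filter bank, built once, offline, with plain NumPy.
    # Each row is a Gaussian window of a different width.
    t = np.linspace(-1.0, 1.0, length)
    widths = np.linspace(0.1, 0.5, num_filters)
    return np.exp(-(t[None, :] / widths[:, None]) ** 2)


@jax.jit
def apply_filters(signal, filters):
    # The filters arrive as an argument instead of being rebuilt every pass.
    # jnp mirrors NumPy: np.dot -> jnp.dot, np.abs -> jnp.abs, etc.
    responses = jnp.abs(jnp.dot(filters, signal))
    # JAX arrays are immutable, so `responses[0] = 0.0` would raise a
    # TypeError; .at[...].set(...) returns an updated copy instead.
    responses = responses.at[0].set(0.0)
    return responses


filters = jnp.asarray(make_filters(num_filters=4, length=8))  # built once
signal = jnp.ones(8)
out = apply_filters(signal, filters)  # reuse `filters` on every call
```

Because `filters` is an ordinary array argument, the jitted function is compiled once for that shape and reused on every call; only the cheap matrix product runs per pass.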
Also remember to consider static arguments (ones that don't change between evaluations, e.g. `L`). Marking them as static, via `static_argnums` or `static_argnames` in `jax.jit`, can help when JIT-compiling.
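A minimal sketch of marking `L` as static, assuming a hypothetical function `weighted_sum` in which `L` determines an array shape (the weighting itself is made up for illustration). Shapes must be known at trace time, so without `static_argnames` the call to `jnp.arange(L)` would fail under `jax.jit`.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("L",))
def weighted_sum(coeffs, L):
    # `L` sets an array shape, so it must be a concrete Python int when the
    # function is traced; marking it static compiles one version per value.
    ell = jnp.arange(L)
    # Hypothetical weighting, for illustration only.
    return jnp.sum(coeffs[:L] * (2 * ell + 1))


coeffs = jnp.ones(16)
a = weighted_sum(coeffs, L=4)  # compiles for L=4
b = weighted_sum(coeffs, L=4)  # reuses the cached compilation
c = weighted_sum(coeffs, L=8)  # a new value of L triggers a recompile
```

Each distinct value of a static argument produces a separate compilation, so this works best for arguments like `L` that take only a few values across a run.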