Adding a JAX backend option in `katsu.katsu_math.py` via the `BackendShim` class, and revamping the `katsu.mueller` functions to handle JAX's sharp bits (most notably, that JAX arrays are immutable), would be a great addition for differentiable Mueller calculus.
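To make the backend idea concrete, here is a minimal sketch of what a shim with a JAX option could look like. It assumes a prysm-style design in which `np` is a single shim object that forwards attribute lookups to the active array module. katsu's actual `BackendShim` may differ, and `set_backend_to_jax` is a hypothetical name used only for illustration.

```python
import numpy as _numpy


class BackendShim:
    """Forward attribute lookups to the currently active array module."""

    def __init__(self, src):
        self._srcmodule = src

    def __getattr__(self, key):
        # Called only when normal lookup fails, so names like np.cos,
        # np.ones_like and np.__name__ resolve on the active module.
        return getattr(self._srcmodule, key)


# Modules that import this one `np` object all see the same shim,
# so swapping _srcmodule switches the backend everywhere at once.
np = BackendShim(_numpy)


def set_backend_to_numpy():
    np._srcmodule = _numpy


def set_backend_to_jax():
    """Hypothetical helper: JAX is imported lazily so it stays optional."""
    import jax
    import jax.numpy as jnp

    # JAX defaults to float32; Mueller calculus usually wants float64.
    # This must run before any JAX arrays are created.
    jax.config.update("jax_enable_x64", True)
    np._srcmodule = jnp
```

After `set_backend_to_jax()`, `np.__name__` evaluates to `'jax.numpy'`, which is exactly what a check like `np.__name__ == 'jax.numpy'` would rely on.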
Here is an example of how functions in `katsu.mueller` could be made JAX-compatible, using a simplified version of the `linear_polarizer` function:
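```python
# Assumes this lives in katsu/mueller.py, where `np` is the BackendShim
# from katsu.katsu_math and _empty_mueller allocates a zeroed Mueller array
def linear_polarizer(a, shape=None):
    """docstring
    """
    # returns array of zeros
    M = _empty_mueller(shape)
    ones = np.ones_like(a)
    cos2a = np.cos(2 * a)
    sin2a = np.sin(2 * a)
    # first row
    if np.__name__ == 'jax.numpy':
        # JAX arrays are immutable: .at[].set() returns a new array,
        # so the result must be assigned back to M
        M = M.at[..., 0, 0].set(ones)
    else:
        M[..., 0, 0] = ones
    return M
```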
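Repeating the `if`/`else` branch for every matrix element gets verbose quickly. One option is to factor it into a small helper; `set_index` below is a hypothetical name, not part of katsu. It dispatches on the array itself (JAX arrays expose an `.at` property, NumPy ndarrays do not) rather than on the global backend, and it always returns the array so callers use the same `M = set_index(...)` pattern on both backends.

```python
import numpy


def set_index(M, index, value):
    """Return M with M[index] = value, for either NumPy or JAX arrays.

    JAX arrays are immutable, so their update goes through .at[...].set(),
    which returns a new array; NumPy arrays are updated in place. Either way
    the caller must use the return value.
    """
    if hasattr(M, "at"):  # JAX arrays expose .at; NumPy ndarrays do not
        return M.at[index].set(value)
    M[index] = value
    return M
```

Inside a function like `linear_polarizer`, the branch would then collapse to `M = set_index(M, (..., 0, 0), ones)`.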
For the full list of these pitfalls, see the "JAX - The Sharp Bits" page in the JAX documentation.
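Another way to sidestep the immutability pitfall entirely is to build the Mueller matrix functionally with `stack` instead of assigning into a preallocated array. The sketch below is a hypothetical `linear_polarizer_stacked` that takes the array module as an `xp` argument (an assumption; katsu selects its backend through the shim instead). It uses the standard ideal-polarizer matrix with a 1/2 normalization, which may not match katsu's convention.

```python
import numpy


def linear_polarizer_stacked(a, xp=numpy):
    """Mueller matrix of an ideal linear polarizer at angle a (radians).

    Built with xp.stack instead of in-place assignment, so the same code
    runs unchanged with xp=numpy or xp=jax.numpy.
    """
    a = xp.asarray(a)
    c = xp.cos(2 * a)
    s = xp.sin(2 * a)
    one = xp.ones_like(c)
    zero = xp.zeros_like(c)
    rows = [
        xp.stack([one, c, s, zero], axis=-1),
        xp.stack([c, c * c, c * s, zero], axis=-1),
        xp.stack([s, c * s, s * s, zero], axis=-1),
        xp.stack([zero, zero, zero, zero], axis=-1),
    ]
    # Result has shape (..., 4, 4), matching the M[..., i, j] convention
    return 0.5 * xp.stack(rows, axis=-2)
```

Because nothing is mutated, passing `xp=jax.numpy` should make the function differentiable end to end, for example `jax.grad(lambda a: linear_polarizer_stacked(a, xp=jnp)[1, 1])(0.3)`.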