Jashcraf / katsu

Polarimetric Data Reduction and machine control for measuring the polarization of observatories
https://katsu.readthedocs.io
MIT License

Tests for hot-swappable numpy backends, demo use #23

Closed: Jashcraf closed this issue 3 months ago

Jashcraf commented 4 months ago

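Adding a JAX backend option in katsu.katsu_math.py via the BackendShim class, and revamping the katsu.mueller functions to handle JAX's sharp bits, would be a great addition, since it would enable differentiable Mueller calculus. The sharp bit that matters most here is that JAX arrays are immutable: in-place item assignment such as M[..., 0, 0] = x raises an error and must be replaced by a functional update, M = M.at[..., 0, 0].set(x).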
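One way to sidestep the immutability problem entirely is to build each Mueller matrix from stacked rows instead of writing into a preallocated array, so the same code runs unchanged on NumPy and jax.numpy. A minimal sketch, assuming an ideal linear polarizer with its transmission axis at angle a (radians); the name linear_polarizer_functional is hypothetical, and plain numpy stands in for katsu's backend shim:

```python
import numpy as np  # stand-in for katsu's shim; jax.numpy provides every call used below


def linear_polarizer_functional(a):
    """Ideal linear polarizer Mueller matrix, built without in-place writes.

    a: transmission-axis angle in radians (scalar or array).
    Returns an array of shape (*np.shape(a), 4, 4).
    """
    a = np.asarray(a, dtype=float)
    c = np.cos(2 * a)
    s = np.sin(2 * a)
    ones = np.ones_like(a)
    zeros = np.zeros_like(a)

    # Each row is stacked along the last axis, then the rows are stacked
    # along axis -2, so element [..., i, j] comes from rows[i][..., j].
    rows = [
        np.stack([ones, c, s, zeros], axis=-1),
        np.stack([c, c * c, c * s, zeros], axis=-1),
        np.stack([s, c * s, s * s, zeros], axis=-1),
        np.stack([zeros, zeros, zeros, zeros], axis=-1),
    ]
    return 0.5 * np.stack(rows, axis=-2)
```

Because nothing is assigned in place, there is no backend branch to maintain; the trade-off is that every element must be written out explicitly, even the zeros.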

Here is an example of how to add JAX compatibility to functions in katsu.mueller, using a simplified version of the linear_polarizer function:

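from katsu.katsu_math import np  # katsu's BackendShim (may wrap numpy, cupy, or jax.numpy)


def linear_polarizer(a, shape=None):
    """docstring
    """

    # returns array of zeros
    M = _empty_mueller(shape)

    ones = np.ones_like(a)
    cos2a = np.cos(2 * a)  # used by the remaining elements in the full function
    sin2a = np.sin(2 * a)

    # first row
    if np.__name__ == 'jax.numpy':
        # JAX arrays are immutable: .at[].set() returns a new array, so reassign M
        M = M.at[..., 0, 0].set(ones)

    else:
        M[..., 0, 0] = ones

    return M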

Here's the link to all of JAX's sharp bits (the "Sharp Bits" page of the JAX documentation).

Jashcraf commented 4 months ago

We will also need to add a set_backend_to_jax function, much like the existing CuPy one:

https://github.com/Jashcraf/katsu/blob/df77f5a631a9b858c1011b810f67422e051fa387/katsu/katsu_math.py#L31-L37
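A set_backend_to_jax function could follow the same pattern as the CuPy one: swap the module that the shim forwards attribute lookups to. The sketch below re-creates a minimal BackendShim so it runs standalone; it is modeled on, not copied from, katsu.katsu_math, and set_backend_to_jax (including its enable_x64 option) is hypothetical:

```python
import numpy as _numpy


class BackendShim:
    """Forward attribute lookups to a swappable array module (sketch)."""

    def __init__(self, src):
        self._srcmod = src

    def __getattr__(self, key):
        # Called only when normal lookup fails, i.e. for everything but _srcmod
        if key == "_srcmod":
            raise AttributeError(key)  # avoid infinite recursion before __init__ runs
        return getattr(self._srcmod, key)


# Module-level shim that the rest of the package would import as np
np = BackendShim(_numpy)


def set_backend_to_numpy():
    """Route array calls back to NumPy."""
    np._srcmod = _numpy


def set_backend_to_jax(enable_x64=True):
    """Hypothetical: route array calls through jax.numpy."""
    import jax
    import jax.numpy as jnp

    if enable_x64:
        # JAX defaults to 32-bit floats; enabling x64 keeps results
        # comparable with the NumPy backend
        jax.config.update("jax_enable_x64", True)
    np._srcmod = jnp
```

With this shim, np.__name__ resolves to the wrapped module's name ('numpy' or 'jax.numpy'), which is what the check in linear_polarizer relies on. The JAX documentation recommends setting jax_enable_x64 at startup, so set_backend_to_jax is best called before any JAX arrays are created.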