chromatix-team / chromatix

Differentiable wave optics using JAX! Documentation can be found at https://chromatix.readthedocs.io
MIT License

Swap order of product with polarizer matrix #116

Closed gschlafly closed 5 months ago

gschlafly commented 6 months ago

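When a polarizer is applied to a field, the product should be computed as polarizer @ field, with the field treated as a column vector. The current code computes the product in the opposite order (field @ polarizer), which is equivalent to multiplying by the transpose of the polarizer matrix. This gives the wrong result whenever the polarizer matrix is not symmetric. The field may have to be transposed into a column vector instead of a row vector.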
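A minimal sketch of why the order matters, using plain jax.numpy and an arbitrary non-symmetric 3x3 matrix and field vector chosen for illustration (no chromatix calls). The row-vector product `v @ P` gives the same components as `P.T @ v`, so it only matches `P @ v` when `P` is symmetric.

```python
import jax.numpy as jnp

# Illustrative non-symmetric matrix (P != P.T) and 3-component field vector.
P = jnp.array([[0.0, 0.0, 0.0], [0.0, 1.0, 1.0], [0.0, -1.0, 1.0]])
v = jnp.array([0.0, 1.0, 2.0])

correct = P @ v        # polarizer acting on the field as a column vector
current = v @ P        # field as a row vector, multiplied from the right
transposed = P.T @ v   # the row-vector product equals this

# For the symmetric part of P, both orders agree, which hides the bug.
S = 0.5 * (P + P.T)
sym_left = S @ v
sym_right = v @ S
```

Here `correct` has components 0, 3, 1 while `current` has 0, -1, 3, so a symmetric test matrix would not have caught the discrepancy.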

https://github.com/chromatix-team/chromatix/blob/7304cd312b28eebc2f15c3c466e53074141d553b/src/chromatix/functional/polarizers.py#L77-L100

gschlafly commented 6 months ago

Example of dot product discrepancies

Here are a vector field and a non-symmetric polarizer matrix:

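import jax.numpy as jnp
import chromatix.functional as cf

field = cf.plane_wave(
    (1, 1),
    1.0,
    0.532,
    1.0,
    amplitude=[0, 1, 2],
    scalar=False,
)
# Non-symmetric: polarizer != polarizer.T
polarizer = jnp.array([[0, 0, 0], [0, 1, 1], [0, -1, 1]])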

We try a few versions of the product. For the operands used here, jnp.dot gives the same output as jnp.matmul.
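The sketch below illustrates two points with plain jax.numpy: `jnp.dot` and `jnp.matmul` agree for the 1-D/2-D operands used after squeezing, and `jnp.dot(a, b)` contracts the last axis of `a` with the second-to-last axis of `b`, which is why only one ordering works without squeezing. The array `u` is a hypothetical stand-in for `field.u`; its shape `(1, 1, 1, 1, 3)` (singleton leading axes, 3 polarization components last) is an assumption, not taken from the chromatix docs.

```python
import jax.numpy as jnp

P = jnp.array([[0.0, 0.0, 0.0], [0.0, 1.0, 1.0], [0.0, -1.0, 1.0]])
# Hypothetical stand-in for field.u (assumed shape: singleton axes, 3 components last).
u = jnp.array([0.0, 1.0, 2.0]).reshape(1, 1, 1, 1, 3)
v = u.squeeze()  # shape (3,)

# With 1-D and 2-D operands, jnp.dot and jnp.matmul agree.
same_left = bool(jnp.allclose(jnp.dot(P, v), jnp.matmul(P, v)))
same_right = bool(jnp.allclose(jnp.dot(v, P), jnp.matmul(v, P)))

# dot(u, P): last axis of u (length 3) meets the first axis of P (length 3), so it works.
right_nd = jnp.dot(u, P)  # shape (1, 1, 1, 1, 3)

# dot(P, u): last axis of P (length 3) meets the second-to-last axis of u (length 1).
try:
    jnp.dot(P, u)
    left_nd_failed = False
except (TypeError, ValueError):
    left_nd_failed = True
```

This matches the "ERROR" entry for "Dot product 1 without squeeze": the unsqueezed field cannot be used on the right of the polarizer without first moving the components into a column.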

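# Assumed definition of the printing helper (it is not shown in this comment).
def format_complex_array(arr):
    return "[" + ", ".join(f"{z.real:.2f}{z.imag:+.2f}j" for z in arr.tolist()) + "]"

dot_prod1 = jnp.dot(polarizer, field.u.squeeze())  # polarizer @ field
dot_prod2 = jnp.dot(field.u.squeeze(), polarizer)  # field @ polarizer (current order)
# dot_prod1_wo_squeeze = jnp.dot(polarizer, field.u)  # raises a shape error
dot_prod2_wo_squeeze = jnp.dot(field.u, polarizer)
print(f"Dot product 1: {format_complex_array(dot_prod1)}")
print(f"Dot product 2: {format_complex_array(dot_prod2)}")
print("Dot product 1 without squeeze: ERROR")
print(f"Dot product 2 without squeeze: {format_complex_array(dot_prod2_wo_squeeze.squeeze())}")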

Dot product 1: [0.00+0.00j, 1.34+0.00j, 0.45+0.00j]
Dot product 2: [0.00+0.00j, -0.45+0.00j, 1.34+0.00j]
Dot product 1 without squeeze: ERROR
Dot product 2 without squeeze: [0.00+0.00j, -0.45+0.00j, 1.34+0.00j]
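The printed values can be reproduced without chromatix if the field vector is taken to be `[0, 1, 2] / sqrt(5)`. This normalization is an inference from the numbers above (it is consistent with `plane_wave` scaling the amplitude to unit power on the 1x1 grid), not something stated in the issue or the docs.

```python
import jax.numpy as jnp

P = jnp.array([[0.0, 0.0, 0.0], [0.0, 1.0, 1.0], [0.0, -1.0, 1.0]])
# Assumed field vector: amplitude [0, 1, 2] normalized to unit norm.
u = jnp.array([0.0, 1.0, 2.0]) / jnp.sqrt(5.0)

dot_prod1 = P @ u  # polarizer @ field: approximately [0, 1.34, 0.45]
dot_prod2 = u @ P  # field @ polarizer: approximately [0, -0.45, 1.34]
```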

Dot product 1 (polarizer @ field) is the intended order of operations. Dot product 2 (field @ polarizer) is what the current implementation computes.
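One possible way to apply the polarizer in the intended order while keeping the polarization components on the last axis is to multiply by the transpose from the right, since `(P @ u_col)^T == u_row @ P.T`. This is a minimal sketch with a hypothetical helper `apply_polarizer`; it is not necessarily how the fix in #132 was implemented. An `einsum` formulation is included as a cross-check.

```python
import jax.numpy as jnp

def apply_polarizer(u, P):
    """Hypothetical helper: apply a 3x3 matrix P to every field vector in u.

    u may have any leading shape; its last axis holds the 3 polarization
    components. At each point this computes P @ u (column-vector convention).
    """
    # Row-vector form of P @ u: multiply by P.T from the right.
    return u @ P.T

def apply_polarizer_einsum(u, P):
    # Same contraction written explicitly: out[..., i] = sum_j P[i, j] * u[..., j]
    return jnp.einsum("ij,...j->...i", P, u)

P = jnp.array([[0.0, 0.0, 0.0], [0.0, 1.0, 1.0], [0.0, -1.0, 1.0]])
# Illustrative batched field: 2 x 3 points, 3 components each.
u = jnp.arange(18.0).reshape(2, 3, 3)
out = apply_polarizer(u, P)
```

Multiplying by `P.T` from the right avoids explicitly transposing the field into a column and works unchanged for any number of leading batch or spatial axes.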

GJBoth commented 5 months ago

Fixed by #132