GalacticDynamics / unxt

Unitful Quantities in JAX
https://unxt.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

Comparison with quantity not working #244

Closed: adrn closed this issue 3 weeks ago

adrn commented 1 month ago

Example:

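x = Quantity(jnp.array([1.0, 0.0, -1.0]), u.kpc)
x > 0.0  # should give [True, False, False]; instead returns an incorrect (all-False) boolean array
x > 1.0  # should raise (kpc vs. a bare nonzero number); instead returns an incorrect (all-False) boolean array
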
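A minimal sketch of the behaviour the example expects, using a toy stand-in rather than unxt itself. The ToyQuantity class and gt helper are hypothetical names, not unxt's API. Comparing a dimensionful quantity against a bare zero is allowed, while any nonzero bare number raises. The values here are concrete (no jit), so plain Python control flow is fine.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass(frozen=True)
class ToyQuantity:
    """Hypothetical stand-in for a unitful array: a value plus a unit string."""

    value: jnp.ndarray
    unit: str  # "" means dimensionless


def gt(x: ToyQuantity, y: float) -> jnp.ndarray:
    """Hypothetical comparison: a bare number only makes sense if it is 0."""
    if x.unit != "" and y != 0:
        # Comparing kpc with a bare nonzero number is ill-defined.
        raise ValueError(f"Cannot compare Q(x, {x.unit!r}) > {y} (except for y=0).")
    # Zero is unit-agnostic, so compare the raw values.
    return x.value > y


x = ToyQuantity(jnp.array([1.0, 0.0, -1.0]), "kpc")
print(gt(x, 0.0))  # [ True False False]
try:
    gt(x, 1.0)
except ValueError as err:
    print(err)
```

The dimensionless case (unit "") accepts any bare number, mirroring the x.unit != one part of the check discussed later in the thread.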
nstarman commented 1 month ago

The problem appears to be

@register(lax.gt_p)
def _gt_p_qv(x: AbstractQuantity, y: ArrayLike) -> ArrayLike:
    try:
        xv = ustrip(one, x)
    except UnitConversionError:
        # Any dimensionful x lands here, so the result is all False whatever y is.
        return jnp.full(_bshape((x, y)), fill_value=False, dtype=bool)

    # re-dispatch on the value
    return qlax.gt(xv, y)
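To see why both comparisons in the report come back as the same wrong array, here is a toy reproduction of the try/except fallback pattern. The helper strip_to_dimensionless is a hypothetical stand-in for ustrip(one, x) that raises for any non-dimensionless unit, and ToyUnitConversionError stands in for UnitConversionError. Because the conversion of x fails before y is ever inspected, every dimensionful comparison falls into the except branch.

```python
import jax.numpy as jnp


class ToyUnitConversionError(Exception):
    """Hypothetical stand-in for astropy's UnitConversionError."""


def strip_to_dimensionless(value, unit: str):
    """Hypothetical ustrip(one, x): only dimensionless ("") values convert."""
    if unit != "":
        raise ToyUnitConversionError(f"{unit!r} is not dimensionless")
    return value


def buggy_gt(value, unit: str, y):
    try:
        xv = strip_to_dimensionless(value, unit)
    except ToyUnitConversionError:
        # The failure is swallowed: the result is all False, regardless of y.
        shape = jnp.broadcast_shapes(jnp.shape(value), jnp.shape(y))
        return jnp.full(shape, fill_value=False, dtype=bool)
    return jnp.greater(xv, y)


x = jnp.array([1.0, 0.0, -1.0])
print(buggy_gt(x, "kpc", 0.0))  # all False, but [True, False, False] was expected
print(buggy_gt(x, "kpc", 1.0))  # all False, but an error was expected
```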

This pattern appears in many of the comparison operators. Partly solving it is trivial: we could just move xv = ustrip(one, x) out of the try block so that the UnitConversionError propagates. That would fix x > 1.0 comparisons, but then x > 0 would also fail, even though comparison with 0 should be allowed. Fully solving this is tricky: under jit, y is an abstract tracer, so we can't concretize it to decide in Python whether it is zero.
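The concretization problem can be seen directly in plain JAX. The function below (a hypothetical name) branches in Python on whether y is zero. Called eagerly it works, but inside jax.jit the value y is a tracer, and converting jnp.all(y == 0) to a Python bool raises a concretization error at trace time.

```python
import jax
import jax.numpy as jnp


def gt_with_python_branch(value, y):
    # Python control flow needs a concrete bool, which a traced y cannot supply.
    if jnp.all(y == 0):
        return value > y
    raise ValueError("Cannot compare a dimensionful value with nonzero y.")


x = jnp.array([1.0, 0.0, -1.0])
print(gt_with_python_branch(x, 0.0))  # eager call: y is concrete, so this works

try:
    jax.jit(gt_with_python_branch)(x, 0.0)
except jax.errors.ConcretizationTypeError as err:
    # Raised while tracing, before any value of y is known.
    print("jit failed:", type(err).__name__)
```

This is why the fix has to move the y == 0 test into a runtime check instead of Python control flow.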

nstarman commented 1 month ago

Solved it! equinox.error_if provides the magic sauce: it attaches a runtime check to x that still works under jit.

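@register(lax.gt_p)
def _gt_p_qv(x: AbstractQuantity, y: ArrayLike) -> ArrayLike:
    # Units are static, so x.unit != one is a plain bool; only y == 0 is traced.
    x = eqx.error_if(  # TODO: customize Exception type
        x,
        x.unit != one and jnp.logical_not(jnp.all(y == 0)),
        f"Cannot compare Q(x, {x.unit}) > y (except for y=0).",
    )
    return qlax.gt(x.value, y)  # re-dispatch on the value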
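The fix above relies on eqx.error_if. As a dependency-light sketch of the same idea, the block below swaps in JAX's built-in jax.experimental.checkify instead of Equinox; it is an analogue, not the code unxt adopted. The unit is modelled as a static Python flag, dimensionless, which is an assumption standing in for x.unit == one. As in the handler, that part of the predicate is fixed at trace time and only the y == 0 test is deferred to runtime. make_checked_gt is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def make_checked_gt(dimensionless: bool):
    """Hypothetical builder: a jitted value > y that rejects nonzero y if dimensionful."""

    def gt(value, y):
        # dimensionless is a static Python bool; only y == 0 is traced.
        ok = jnp.logical_or(dimensionless, jnp.all(y == 0))
        checkify.check(ok, "Cannot compare a dimensionful quantity > y (except for y=0).")
        return jnp.greater(value, y)

    # checkify turns the check into a returned Error value, which works under jit.
    return jax.jit(checkify.checkify(gt))


gt_kpc = make_checked_gt(dimensionless=False)
x = jnp.array([1.0, 0.0, -1.0])

err, out = gt_kpc(x, 0.0)
err.throw()  # no-op: comparing with 0 is allowed
print(out)  # [ True False False]

err, _ = gt_kpc(x, 1.0)
print(err.get())  # the failure message, since y != 0
```

One design difference: checkify returns the error alongside the output and leaves raising to the caller, whereas eqx.error_if raises at runtime on its own.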

We lose the UnitConversionError, since error_if raises its own runtime error type, which isn't ideal; maybe we can push a PR to equinox to allow customizing the exception class. Either way, I think that's a secondary concern compared to the existing bug, which silently returns wrong results.
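One way to recover a unit-specific exception without changing the checking library, sketched with JAX's checkify in place of Equinox: run the jitted, checked comparison, then translate any failure into a custom exception outside jit. UnitComparisonError, _gt and checked_gt are hypothetical names. This only helps at a top-level call boundary; code that itself runs inside a larger jit still sees the generic error.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


class UnitComparisonError(ValueError):
    """Hypothetical unit-specific error, standing in for UnitConversionError."""


def _gt(value, y):
    # Dimensionful case only: the comparison is allowed just for y == 0.
    checkify.check(jnp.all(y == 0), "Cannot compare a dimensionful quantity > y (except for y=0).")
    return jnp.greater(value, y)


_checked_gt = jax.jit(checkify.checkify(_gt))


def checked_gt(value, y):
    """Run the jitted check, then raise a custom exception type outside jit."""
    err, out = _checked_gt(value, y)
    msg = err.get()  # None when every check passed
    if msg is not None:
        raise UnitComparisonError(msg)
    return out


x = jnp.array([1.0, 0.0, -1.0])
print(checked_gt(x, 0.0))  # [ True False False]
try:
    checked_gt(x, 1.0)
except UnitComparisonError as exc:
    print("caught:", type(exc).__name__)
```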