Closed: adrn closed this 3 weeks ago.
The problem appears to be
@register(lax.gt_p)
def _gt_p_qv(x: AbstractQuantity, y: ArrayLike) -> ArrayLike:
    try:
        xv = ustrip(one, x)
    except UnitConversionError:
        # Incompatible units: silently return all-False instead of raising.
        return jnp.full(_bshape((x, y)), fill_value=False, dtype=bool)
    # re-dispatch on the value
    return qlax.gt(xv, y)
which appears in many of the comparison operators.
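To make the failure mode concrete, here is a minimal pure-Python/NumPy stand-in for the dispatch rule above. ToyQuantity, strip_to_dimensionless and the local UnitConversionError are hypothetical stand-ins (not unxt's API), and unit handling is reduced to a string check where "" means dimensionless. It shows that a dimensionful x compared to a plain number silently yields all False instead of raising.

```python
from dataclasses import dataclass

import numpy as np


class UnitConversionError(Exception):
    """Hypothetical stand-in for a units library's conversion error."""


@dataclass
class ToyQuantity:
    """Hypothetical stand-in for AbstractQuantity: a value plus a unit string."""
    value: np.ndarray
    unit: str  # "" means dimensionless in this toy


def strip_to_dimensionless(x: ToyQuantity) -> np.ndarray:
    # Stand-in for ustrip(one, x): only dimensionless quantities convert.
    if x.unit != "":
        raise UnitConversionError(f"{x.unit!r} is not dimensionless")
    return np.asarray(x.value)


def gt_current(x: ToyQuantity, y) -> np.ndarray:
    # Mirrors the try/except structure of _gt_p_qv.
    try:
        xv = strip_to_dimensionless(x)
    except UnitConversionError:
        # Incompatible units: answer False everywhere instead of raising.
        shape = np.broadcast_shapes(np.shape(x.value), np.shape(y))
        return np.full(shape, fill_value=False, dtype=bool)
    return xv > y


x = ToyQuantity(np.array([5.0, -5.0]), "m")
print(gt_current(x, 1.0))  # [False False]: no error, although "5 m > 1.0" is ill-defined
```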
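Partly solving this is trivial: we could just move xv = ustrip(one, x) out of the try/except block so that the UnitConversionError propagates. That would fix x > 1.0 comparisons, which would then raise instead of silently returning False.
But then x > 0 would also fail, even though comparing against zero is meaningful in any unit. Fully solving this is tricky because, under tracing, we can't concretize y to use it in Python conditional logic.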
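In eager (untraced) code the full rule is easy to express with ordinary Python control flow: compare bare values when x is dimensionless, and otherwise allow the comparison only when every y is zero. A minimal NumPy sketch, with the hypothetical gt_eager function and units again reduced to strings ("" means dimensionless):

```python
import numpy as np


class UnitConversionError(Exception):
    """Hypothetical stand-in for a units library's conversion error."""


def gt_eager(x_value, x_unit: str, y) -> np.ndarray:
    """Hypothetical eager version of the comparison rule (units as strings)."""
    x_value, y = np.asarray(x_value), np.asarray(y)
    if x_unit == "":
        # Dimensionless x: compare the bare values.
        return x_value > y
    # Dimensionful x vs a plain number only makes sense against zero.
    if np.all(y == 0):  # Python `if` on a concrete NumPy value works here
        return x_value > y
    raise UnitConversionError(f"Cannot compare Q(x, {x_unit!r}) > y (except for y=0).")


print(gt_eager([1.0, -1.0], "m", 0))  # [ True False]
```

The `if np.all(y == 0)` branch is the part that does not carry over to jax.jit, where y is an abstract tracer with no concrete value to branch on.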
Solved it! equinox.filter_jit, together with eqx.error_if, provides the magic sauce:
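@register(lax.gt_p)
def _gt_p_qv(x: AbstractQuantity, y: ArrayLike) -> ArrayLike:
    # Runtime check that also works under jit: error unless x is dimensionless or y is all zeros.
    x = eqx.error_if(  # TODO: customize Exception type
        x,
        x.unit != one and jnp.logical_not(jnp.all(y == 0)),
        f"Cannot compare Q(x, {x.unit}) > y (except for y=0).",
    )
    return qlax.gt(x.value, y)  # re-dispatch on the value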
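The error_if version works under tracing because its predicate is an array expression rather than a Python bool, so y never has to be concretized. A minimal JAX-only sketch (no equinox) contrasting the two; it assumes, as the Python-level x.unit != one in the fix suggests, that the unit is static metadata, and unit_mismatch, branch_on_y and error_pred are hypothetical names standing in for that logic:

```python
import jax
import jax.numpy as jnp


@jax.jit
def branch_on_y(y):
    # Python `if` needs a concrete bool, but y is a tracer here.
    if jnp.all(y == 0):
        return jnp.int32(0)
    return jnp.int32(1)


def error_pred(unit_mismatch: bool, y):
    # Mirrors the fix's condition: x.unit != one and not all(y == 0).
    # unit_mismatch is a static Python bool; the y-dependent part is
    # built from jnp ops, so it stays a traced boolean array.
    return jnp.logical_and(unit_mismatch, jnp.any(y != 0))


error_pred_jit = jax.jit(error_pred, static_argnums=0)

try:
    branch_on_y(jnp.zeros(3))
except jax.errors.ConcretizationTypeError as e:
    print("branching on y fails under jit:", type(e).__name__)

print(error_pred_jit(True, jnp.zeros(3)))            # False: y is all zeros
print(error_pred_jit(True, jnp.array([0.0, 1.0])))   # True: the check would fire
print(error_pred_jit(False, jnp.array([0.0, 1.0])))  # False: x is dimensionless
```

This only covers the predicate half; eqx.error_if is what turns a True predicate into a runtime error.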
We lose the UnitConversionError, which isn't ideal, but maybe we can push a PR to equinox to allow for customizing the exception class. Either way, I think that's a secondary concern compared to the existing bug.
Example: