Randl closed this 7 months ago.
Looks like the failed test is just an unlucky seed?
So I think it's probably best to avoid casting absolutely everything to the same dtype, as discussed over in #64.
Probably the simplest thing to do is just to locally allow dtype casting -- I do this in a few places where it's awkward not to:
```python
import jax

with jax.numpy_dtype_promotion("standard"):
    ...  # code that needs to mix dtypes goes here
```
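A minimal sketch of what the local override buys you. It assumes the surrounding code (for example, a test suite) runs with strict promotion; here strict mode is switched on with the same context manager purely for illustration. Under strict mode, mixing `float32` and `complex64` raises, while the inner `"standard"` block allows the implicit promotion to `complex64`.

```python
import jax
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.float32)
z = jnp.ones(3, dtype=jnp.complex64)

# Assumption for illustration: the surrounding code runs with strict promotion,
# under which mixing float32 and complex64 raises instead of promoting silently.
with jax.numpy_dtype_promotion("strict"):
    try:
        x + z
        strict_raised = False
    except Exception:  # JAX raises a type-promotion error here
        strict_raised = True

    # Locally opt back in to standard promotion for the one awkward spot.
    with jax.numpy_dtype_promotion("standard"):
        y = x + z  # float32 + complex64 promotes to complex64

print(strict_raised, y.dtype)
```

The inner context manager overrides the outer one only for its own block, so strictness is preserved everywhere else.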
As for an unlucky seed -- yup, that seems plausible. GMRES is such a finicky algorithm. I'm not sure I have a nice way to guard against this error, so I'm happy to just ignore it.
So there are four cases where this PR introduces casts:
DivLinearOperator
operator

So I think for 2 and 4 it makes sense to cast, since the value multiplies the input, and for 1 it doesn't really matter. No strong opinion about 3. What do you think?
I think I'd probably suggest using `with jax.numpy_dtype_promotion("standard")` for all of them, actually! For example, on 2: note that right now, if you have a pytree of `(f32, c64)`, then the scalar will get promoted to `result_type = c64`, and as such the entire output will be as well.
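A small sketch of the `(f32, c64)` case. The "cast everything" branch is a hypothetical stand-in for casting the scalar to the result type of all leaves, not the PR's actual code. Multiplying each leaf by a plain Python scalar under standard promotion keeps each leaf's own dtype, whereas casting the scalar to the pytree-wide result type turns the `float32` leaf into `complex64`.

```python
import jax
import jax.numpy as jnp

# A pytree with one real and one complex leaf.
tree = (jnp.ones(2, dtype=jnp.float32), jnp.ones(2, dtype=jnp.complex64))
scalar = 2.0

with jax.numpy_dtype_promotion("standard"):
    # Hypothetical "cast everything" approach: cast the scalar to the
    # result type of all leaves (complex64 here) before multiplying.
    dtype = jnp.result_type(*jax.tree_util.tree_leaves(tree))
    scalar_cast = jnp.asarray(scalar, dtype=dtype)
    cast_out = jax.tree_util.tree_map(lambda leaf: leaf * scalar_cast, tree)

    # Local promotion: a weakly typed Python scalar adopts each leaf's dtype.
    local_out = jax.tree_util.tree_map(lambda leaf: leaf * scalar, tree)

print([leaf.dtype for leaf in cast_out])
print([leaf.dtype for leaf in local_out])
```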
LGTM -- thank you as always! Where are we at now for complex support in Lineax?
Split off from https://github.com/patrick-kidger/lineax/pull/64, see https://github.com/patrick-kidger/lineax/issues/57