patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax
Apache License 2.0

Add casts in additional place to satisfy strict typing for complex case #89

Closed: Randl closed this 7 months ago

Randl commented 8 months ago

Split off from https://github.com/patrick-kidger/lineax/pull/64; see https://github.com/patrick-kidger/lineax/issues/57.

Randl commented 8 months ago

Looks like the failed test is just an unlucky seed?

patrick-kidger commented 8 months ago

So I think it's probably best to avoid casting absolutely everything to the same dtype, as discussed over in #64.

Probably the simplest thing to do is just to locally allow dtype casting -- I do this in a few places where it's awkward not to:

import jax

with jax.numpy_dtype_promotion("standard"):
    ...  # implicit dtype promotion is allowed inside this block only
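A minimal sketch of that local-override pattern, assuming the surrounding code runs under strict promotion (the "strict typing" in this PR's title). The helper name mixed_product is hypothetical; jax.numpy_dtype_promotion is JAX's own context manager, and when such contexts are nested the innermost setting applies.

```python
import jax
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.float32)
z = jnp.ones(3, dtype=jnp.complex64)

def mixed_product(a, b):
    # Hypothetical helper: re-enable implicit promotion for this one awkward
    # spot, even if the caller is running under strict promotion.
    with jax.numpy_dtype_promotion("standard"):
        return a * b

with jax.numpy_dtype_promotion("strict"):
    try:
        x * z  # float32 * complex64: no implicit promotion allowed in strict mode
        strict_failed = False
    except Exception:  # JAX raises a type-promotion error here
        strict_failed = True
    out = mixed_product(x, z)  # the inner "standard" context wins

print(strict_failed, out.dtype)
```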

As for an unlucky seed -- yup, that seems plausible. GMRES is such a finicky algorithm. I'm not sure I have a nice way to guard against this error, so I'm happy to just ignore it.

Randl commented 8 months ago

So there are four cases where this PR introduces casts:

  1. Finite difference in test helpers.
  2. Scalar in DivLinearOperator
  3. Scales in CG and friends
  4. eps in GMRES normalization.

So I think for 2 and 4 it makes sense to cast since the value multiplies the input, and for 1 it doesn't really matter. No strong opinion about 3. What do you think?
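For cases 2 and 4, the explicit-cast option amounts to something like the sketch below. The scale helper is hypothetical and is not Lineax's code. It only shows that casting the scalar to the dtype of the input it multiplies avoids implicit promotion entirely, so it runs even in strict mode.

```python
import jax
import jax.numpy as jnp

def scale(x, s):
    # Hypothetical helper: explicitly cast the scalar to the input's dtype,
    # so the multiplication involves a single dtype and needs no promotion.
    return x * jnp.asarray(s, dtype=x.dtype)

x = jnp.ones(3, dtype=jnp.complex64)
s = jnp.float32(0.5)

with jax.numpy_dtype_promotion("strict"):
    y = scale(x, s)  # explicit casts are permitted under strict promotion

print(y.dtype)
```

One caveat with this approach: it works cleanly when casting a real scalar up to a complex input. If the scalar were complex and the input real, casting it to the input's dtype would discard the imaginary part.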

patrick-kidger commented 8 months ago

I think I'd probably suggest using a with jax.numpy_dtype_promotion("standard") block for all of them, actually! For example, on 2: note that right now, if you have a pytree of (f32, c64) leaves, then the scalar gets promoted to result_type = c64, and as such the entire output (including the f32 leaf) will be c64 as well.
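To make the pytree point concrete, here is a sketch comparing the two strategies on a hypothetical (f32, c64) pytree divided by an f32 scalar. Neither computation is Lineax's actual DivLinearOperator implementation; both run under JAX's default "standard" promotion and only contrast the resulting leaf dtypes.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

# Hypothetical mixed-dtype pytree: one real leaf and one complex leaf.
tree = (jnp.ones(2, dtype=jnp.float32), jnp.ones(2, dtype=jnp.complex64))
scalar = jnp.float32(2.0)

# Strategy 1: cast the scalar to the result_type of *all* leaves (complex64).
# Every leaf is then divided by a complex scalar, so every output is complex.
common = jnp.result_type(*jtu.tree_leaves(tree))
cast_out = jtu.tree_map(lambda leaf: leaf / scalar.astype(common), tree)

# Strategy 2: leave the scalar alone and allow promotion locally.
# Each leaf is promoted only as far as it needs to be, so the f32 leaf stays f32.
with jax.numpy_dtype_promotion("standard"):
    local_out = jtu.tree_map(lambda leaf: leaf / scalar, tree)

print([leaf.dtype for leaf in cast_out], [leaf.dtype for leaf in local_out])
```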

patrick-kidger commented 7 months ago

LGTM -- thank you as always! Where are we at now for complex support in Lineax?