On affine solvers: affine systems, f(y) = Ay + b, can be handled with a single linear solve. JAX can detect affine functions via:
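import jax
import jax.interpreters.partial_eval as pe

def is_affine(f, *args, **kwargs):
    # Trace the Jacobian of f (w.r.t. its first argument) to a jaxpr.
    jaxpr = jax.make_jaxpr(jax.jacfwd(f))(*args, **kwargs)
    # Dead-code-eliminate, keeping every output, and record which inputs are still used.
    _, used_inputs = pe.dce_jaxpr(jaxpr.jaxpr, [True] * len(jaxpr.out_avals))
    # A Jacobian that depends on no input is constant, so f is affine.
    # Conservative: dependence on any argument (not just the first) returns False.
    return not any(used_inputs)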
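A minimal sketch of how such a check might feed into an affine solve, assuming a square root-find f(y) = 0 with f mapping R^n to R^n. For affine f the Jacobian A is constant, so one Newton step from any starting point y0, y* = y0 - A^{-1} f(y0), lands exactly on the root. The helper `affine_root` and the example functions `affine_fn` and `nonlinear_fn` are hypothetical, for illustration only.

```python
import jax
import jax.numpy as jnp
import jax.interpreters.partial_eval as pe


def is_affine(f, *args, **kwargs):
    # Same check as above: the Jacobian's jaxpr must not depend on any input.
    jaxpr = jax.make_jaxpr(jax.jacfwd(f))(*args, **kwargs)
    _, used_inputs = pe.dce_jaxpr(jaxpr.jaxpr, [True] * len(jaxpr.out_avals))
    return not any(used_inputs)


def affine_root(f, y0):
    # Hypothetical helper: for affine f(y) = A y + b the Jacobian A is constant,
    # so a single Newton step y0 - A^{-1} f(y0) is the exact root.
    A = jax.jacfwd(f)(y0)
    return y0 - jnp.linalg.solve(A, f(y0))


A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, -1.0])


def affine_fn(y):
    return A @ y + b


def nonlinear_fn(y):
    return jnp.sin(y) + y


y0 = jnp.zeros(2)
print(is_affine(affine_fn, y0))     # True
print(is_affine(nonlinear_fn, y0))  # False
root = affine_root(affine_fn, y0)   # approximately [-0.6, 0.8]
print(jnp.allclose(affine_fn(root), 0.0, atol=1e-5))  # True
```

Materialising the dense Jacobian with `jax.jacfwd` is only reasonable for small n; larger systems would more likely hand the linear operator to a dedicated linear solver instead.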
Powell's (unconstrained) derivative-free optimisers: