@inducer [more of an FYI than an issue?] This pertains to the Pyro version that uses arraycontext (PR #58), specifically when using `EagerJAXArrayContext` to compute chemical Jacobians via `jax.jacfwd`. Everything works until the `jacfwd` call. The specific error is:
```
TypeError: EagerJAXArrayContext.full_like invoked with an unsupported array type: got 'JVPTracer', but expected one of (<class 'jaxlib.xla_extension.DeviceArrayBase'>,)
```
Initially, we took this to mean that JAX could not differentiate through the Newton iteration for temperature. However, I applied `jacfwd` to this function:
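```python
import jax

# pyro_gas: Pyro thermochemistry object using EagerJAXArrayContext.
# state[0] is temperature; state[1:] are (presumably) the species mass fractions.
def chemical_source_terms(state):
    density = pyro_gas.get_density(pyro_gas.one_atm, state[0], state[1:])
    return pyro_gas.wts[:, None] * pyro_gas.get_net_production_rates(density, state[0], state[1:]) / density

jacobian = jax.jacfwd(chemical_source_terms)
```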
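Here the temperature is an independent variable (`state[0]`), so Newton is never invoked, and the same error still occurs. What is happening is that `EagerJAXArrayContext` does not accept `JVPTracer`, the tracer type JAX passes through the code in place of concrete arrays while tracing it for forward-mode differentiation.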
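A minimal standalone sketch of that mechanism, independent of Pyro and arraycontext (the function `f` and the `seen_types` list are hypothetical): it records the type of the argument JAX hands to `f` during `jax.jacfwd`. In the JAX versions I am aware of, this is a `JVPTracer` rather than the concrete array type of the input, so a type check that only admits concrete device arrays (presumably like the one behind the `full_like` error above) rejects it.

```python
import jax
import jax.numpy as jnp

# Hypothetical record of the argument types that JAX passes into f.
seen_types = []


def f(x):
    # Under jax.jacfwd, x is a tracer standing in for the concrete input.
    seen_types.append(type(x).__name__)
    return jnp.sin(x) * x


x0 = jnp.array([1.0, 2.0])
jacobian = jax.jacfwd(f)(x0)

print(type(x0).__name__)  # concrete array type of the input
print(seen_types)         # tracer type seen while differentiating
print(jacobian.shape)     # (2, 2)
```

Any array-type check that only admits the concrete array class will therefore fail as soon as it is reached inside `jacfwd`, whether or not a Newton solve is involved.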
In summary, this is not so much a Pyro issue as an arraycontext one, so the explicit nonlinear solve is not as urgent as it seemed last week (I think). I dug in and couldn't find a way to get `JVPTracer` through to arraycontext, but I think this is not the right place for that discussion.