pyrometheus / pyrometheus

Code generation for thermochemistry

Jax Array Context and AD #59

Closed: ecisneros8 closed this issue 9 months ago

ecisneros8 commented 2 years ago

@inducer [more of an FYI than an issue?] This pertains to the Pyro version that uses arraycontext (PR #58), specifically when using EagerJAXArrayContext to compute chemical Jacobians via jax.jacfwd. Everything works until jax.jacfwd is applied, which fails with:

TypeError: EagerJAXArrayContext.full_like invoked with an unsupported array type: got 'JVPTracer', but expected one of (<class 'jaxlib.xla_extension.DeviceArrayBase'>,)

Initially, we interpreted this as an inability to differentiate through the Newton iteration used to compute the temperature. However, I applied jacfwd to this function:

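# pyro_gas: Pyrometheus-generated mechanism object, created elsewhere
def chemical_source_terms(state):
    # state[0] is the temperature; state[1:] are the species mass fractions
    density = pyro_gas.get_density(pyro_gas.one_atm, state[0], state[1:])
    # Mass-fraction source terms: W_k * omega_k / rho
    return pyro_gas.wts[:, None] * pyro_gas.get_net_production_rates(density, state[0], state[1:]) / density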

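where the temperature is now an independent variable (state[0]), so the Newton iteration is never invoked. The actual cause is that EagerJAXArrayContext does not recognize JVPTracer, the tracer type JAX passes through the code during forward-mode differentiation, as a supported array type, so calls such as full_like reject the traced inputs.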
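The failure mode can be reproduced in miniature without JAX or arraycontext. In the toy sketch below, a hypothetical `Dual` class (a primal value plus a tangent) stands in for `JVPTracer`, and a hypothetical `make_full_like` builds a `full_like` that checks its input against a fixed tuple of types, loosely mimicking the error message above. Neither mirrors the real JAX or arraycontext code; the point is only that forward-mode differentiation feeds a non-array object into the function, so a type whitelist fires.

```python
from dataclasses import dataclass


@dataclass
class Dual:
    """Toy forward-mode value (a stand-in for JAX's JVPTracer, not the real thing)."""
    primal: float
    tangent: float

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        # Product rule for the tangent
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __rmul__ = __mul__


def make_full_like(supported_types):
    """Hypothetical full_like that type-checks its input like an array context."""
    def full_like(ary, fill_value):
        if not isinstance(ary, supported_types):
            raise TypeError(
                f"full_like invoked with an unsupported array type: "
                f"got {type(ary).__name__!r}, but expected one of {supported_types}")
        return fill_value  # scalar stand-in for an array filled with fill_value
    return full_like


def source_term(temperature, full_like):
    # Any array-context call on the traced input hits the type check
    return temperature * temperature + full_like(temperature, 1.0)


def derivative(f, x):
    """Forward-mode derivative: seed the tangent with 1 and read it back."""
    return f(Dual(x, 1.0)).tangent


strict = make_full_like((float,))
print(source_term(3.0, strict))  # 10.0
```

With `strict`, `derivative(lambda t: source_term(t, strict), 3.0)` raises `TypeError`, while `make_full_like((float, Dual))` lets the same call return `6.0`, the derivative of T² + 1 at T = 3. This only illustrates why a whitelist of concrete array types breaks tracing; it says nothing about how the newer JAX release resolved the issue.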

In summary, the problem lies in arraycontext rather than in Pyro, so the explicit nonlinear solve is less urgent than it seemed last week (I think). I dug in but couldn't find a way to pass the JVPTracer through arraycontext; that discussion probably belongs elsewhere.
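As a sanity check that the source-term computation itself is differentiable, the same structure can be written in plain `jax.numpy`, with no array context in the path. The sketch below assumes a hypothetical one-step mechanism A → B with made-up Arrhenius parameters and molecular weights; the helper names echo the Pyrometheus methods used above but are local stand-ins, not generated code. Because the state here is one-dimensional, the molecular weights are not broadcast with `[:, None]`.

```python
import jax
import jax.numpy as jnp

# Hypothetical one-step mechanism A -> B; all parameters are made up.
MOLECULAR_WEIGHTS = jnp.array([0.028, 0.028])  # kg/mol, species A and B
PRE_EXPONENTIAL = 1.0e6  # 1/s
ACTIVATION_TEMPERATURE = 1.5e4  # K
GAS_CONSTANT = 8.314  # J/(mol K)
ONE_ATM = 101325.0  # Pa


def get_density(pressure, temperature, mass_fractions):
    # Ideal-gas density from the mixture molecular weight
    mix_molecular_weight = 1.0 / jnp.sum(mass_fractions / MOLECULAR_WEIGHTS)
    return pressure * mix_molecular_weight / (GAS_CONSTANT * temperature)


def get_net_production_rates(density, temperature, mass_fractions):
    # Molar concentrations (mol/m^3) and an Arrhenius rate for A -> B
    concentrations = density * mass_fractions / MOLECULAR_WEIGHTS
    rate_constant = PRE_EXPONENTIAL * jnp.exp(-ACTIVATION_TEMPERATURE / temperature)
    rate = rate_constant * concentrations[0]
    return jnp.stack([-rate, rate])  # A is consumed, B is produced


def chemical_source_terms(state):
    # state[0]: temperature; state[1:]: mass fractions of A and B
    density = get_density(ONE_ATM, state[0], state[1:])
    return MOLECULAR_WEIGHTS * get_net_production_rates(density, state[0], state[1:]) / density


state = jnp.array([1500.0, 0.9, 0.1])
# Forward-mode Jacobian d(source terms)/d(state), shape (n_species, 1 + n_species)
jacobian = jax.jacfwd(chemical_source_terms)(state)
print(jacobian.shape)  # (2, 3)
```

Each column of the Jacobian comes from one forward-mode pass with a unit tangent on one entry of `state`. With equal molecular weights the two source terms reduce to -k(T) Y_A and +k(T) Y_A, so every column sums to zero and the derivative of the B source term with respect to Y_A equals k(T).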

ecisneros8 commented 9 months ago

The new version of JAX fixes this issue.