Summary
Integrates the perturbation module with arraylias. Depends on #281.
Note that this PR also updates the perturbative solvers in the `solvers` folder.
Details and comments
Interface changes:
The `backend` argument of `_CustomBinaryOp`, which was used to decide whether to use JAX looping logic, has been dropped. `_CustomBinaryOp.__call__` now makes this decision with `_preferred_lib`, and uses JAX looping logic if either `A` or `B` is a JAX array.
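To illustrate type-based dispatch, here is a minimal sketch. It assumes a simplified model of the operation, `output[k] = sum of A[i] @ B[j] over the index pairs (i, j) in rule[k]`. `preferred_lib` and `custom_binary_op` are hypothetical stand-ins, not the actual `_preferred_lib` or `_CustomBinaryOp` code. The JAX branch imports JAX lazily, so NumPy-only inputs never require it.

```python
import numpy as np


def preferred_lib(*args):
    """Hypothetical stand-in for _preferred_lib: return "jax" if any argument
    is a JAX array (or tracer), otherwise "numpy". JAX arrays live in the
    "jax" or "jaxlib" packages, so no JAX import is needed to check."""
    for arg in args:
        if type(arg).__module__.split(".")[0] in ("jax", "jaxlib"):
            return "jax"
    return "numpy"


def custom_binary_op(A, B, rule):
    """Simplified custom product (hypothetical):
    output[k] = sum over (i, j) in rule[k] of A[i] @ B[j].
    A has shape (n, d, m), B has shape (p, m, q), rule is a list of
    non-empty lists of (i, j) index pairs."""
    if preferred_lib(A, B) == "jax":
        import jax
        import jax.numpy as jnp

        A, B = jnp.asarray(A), jnp.asarray(B)
        # Flatten the static rule into index arrays plus a segment id per pair,
        # so the whole computation is vectorized and traceable under jit.
        i_idx = jnp.array([i for pairs in rule for i, _ in pairs])
        j_idx = jnp.array([j for pairs in rule for _, j in pairs])
        seg = jnp.array([k for k, pairs in enumerate(rule) for _ in pairs])
        products = jnp.matmul(A[i_idx], B[j_idx])
        return jax.ops.segment_sum(products, seg, num_segments=len(rule))

    # NumPy path: plain Python loops are fine outside of JAX tracing.
    A, B = np.asarray(A), np.asarray(B)
    out = np.zeros((len(rule), A.shape[1], B.shape[2]), dtype=np.result_type(A, B))
    for k, pairs in enumerate(rule):
        for i, j in pairs:
            out[k] += A[i] @ B[j]
    return out
```

Because the decision depends only on the operand types, the caller never passes a backend. A JAX array in either position selects the JAX path automatically.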
`solve_lmde_perturbation` now decides whether to use JAX control flow based on whether the ODE solver method is JAX compatible, rather than by checking `Array.default_backend()`.
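The sketch below shows what choosing control flow from the solver method, rather than from a global default backend, might look like. It assumes that JAX-compatible methods are strings prefixed with `"jax_"` (for example `"jax_odeint"`) or solver objects from the `diffrax` package. `is_jax_method` and `run_loop` are hypothetical helpers for illustration only.

```python
def is_jax_method(method):
    """Hypothetical check: treat string methods prefixed with "jax_" and
    solver objects from the diffrax package as JAX compatible."""
    if isinstance(method, str):
        return method.startswith("jax_")
    return type(method).__module__.split(".")[0] == "diffrax"


def run_loop(num_steps, body, init, use_jax):
    """Run val = body(i, val) for i in range(num_steps).

    With use_jax=True this uses jax.lax.fori_loop so the loop can be traced
    and compiled; otherwise it is an ordinary Python loop."""
    if use_jax:
        from jax import lax

        return lax.fori_loop(0, num_steps, body, init)
    val = init
    for i in range(num_steps):
        val = body(i, val)
    return val
```

The solver choice then fully determines the loop style. For example, `run_loop(n, body, init, is_jax_method(method))` never enters a JAX loop for a NumPy-based method such as `"RK45"`.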
For perturbative solvers:
Whether `ExpansionModel` is constructed to use JAX is now determined solely by the `integration_method` used.
`.solve` for the perturbative solvers now takes a `jax_control_flow` argument that specifies whether to use JAX when solving. If it is not specified, this is inferred from the array types used in the underlying expansion model, the type of the initial state, and whether the method is being called within a JAX transformation.
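The default-resolution rule for `jax_control_flow` can be sketched as follows. Several details are assumptions of this sketch: "not specified" is modeled as `None`, and the expansion model's arrays are passed as a plain list. A call inside `jax.jit`/`jax.grad`/`jax.vmap` is detected because the inputs are then JAX tracers, which, like JAX arrays, have types defined in the `jax`/`jaxlib` packages. `resolve_jax_control_flow` is a hypothetical helper, not the solvers' actual code.

```python
import numpy as np


def _is_jax_type(x):
    """True if x is a JAX array or a JAX tracer (tracers appear when the call
    happens inside a JAX transformation such as jit, grad, or vmap)."""
    return type(x).__module__.split(".")[0] in ("jax", "jaxlib")


def resolve_jax_control_flow(jax_control_flow, model_operators, y0):
    """Hypothetical helper for the documented default of jax_control_flow.

    An explicit True/False always wins. Otherwise, use JAX control flow if any
    operator in the expansion model, or the initial state y0, is a JAX type."""
    if jax_control_flow is not None:
        return bool(jax_control_flow)
    return any(_is_jax_type(op) for op in model_operators) or _is_jax_type(y0)
```

With NumPy operators and a NumPy initial state, the default resolves to `False`. Passing `jax_control_flow=True` explicitly overrides the inference.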