jsbrittain opened this pull request 2 weeks ago
Note: Codacy seems to be struggling with the template / inheritance structures in C++, hence the inclusion of additional `// cppcheck-suppress` comments in the IDAKLU solver.
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 99.57%. Comparing base (`e22d10c`) to head (`efb0800`). Report is 3 commits behind head on develop.
ccing @cringeyburger here because this PR will be adding and modifying a lot of compiled code; we shall need to make a few adjustments in the migration to scikit-build-core as needed. Though, as long as the wheel builds pass (@jsbrittain, could you please trigger them on your fork?), it should be fine.
@agriyakhetarpal Wheels build fine: https://github.com/jsbrittain/PyBaMM/actions/runs/9665563143
thanks @jsbrittain, this looks excellent. I've made a few suggestions below, see what you think. Also, is `jax_evaluator="iree"` needed for the options? Is there a case where you want to convert to jax but not use the jax_evaluator?
@martinjrobins yes, there is actually an existing python-idaklu interface that will run if we don't redirect using the (new) `jax_evaluator` option. I think it's a legacy item (`idaklu/python.cpp`), and it can be quite slow, even on these toy examples.
That reminds me, we should get rid of python-idaklu; I don't think it serves a useful purpose anymore. I'll add an issue.
Description
Add a new expression evaluation backend to the IDAKLU solver. MLIR expression evaluation is now supported by lowering PyBaMM's Jax-based expressions into MLIR, which are then compiled and executed as part of the IDAKLU solver using IREE.
To enable the IREE/MLIR backend, set the (new) `PYBAMM_IDAKLU_EXPR_IREE` compiler flag to `ON` via an environment variable and install PyBaMM using the developer method (by default, `PYBAMM_IDAKLU_EXPR_IREE` is turned `OFF`).

Expression evaluation with IREE in IDAKLU is enabled by constructing the model using Jax expressions (`model.convert_to_format = "jax"`) and setting the solver backend (`jax_evaluator="iree"`).

Note that IREE currently only supports single-precision floating-point operations, which requires the model to be demoted from 64-bit to 32-bit precision before the solver can run. This is handled within the solver logic, but the operation is performed in place on the PyBaMM battery model (a warning is displayed when this happens). Operating at lower precision requires tolerances to be relaxed for convergence on larger (e.g. DFN) models, and leads to memory transfers and type casting in the solver which currently cause slow-downs (at least until 64-bit computation is natively supported).
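The usage described above can be sketched as follows. This is a minimal, illustrative sketch: the SPM model choice, time span, tolerances, and the editable-install command shown in the comment are assumptions for demonstration, not part of this PR.

```python
import pybamm

# Assumes PyBaMM was installed via the developer method with the IREE backend
# enabled, e.g. (illustrative command): PYBAMM_IDAKLU_EXPR_IREE=ON pip install -e .

model = pybamm.lithium_ion.SPM()
model.convert_to_format = "jax"  # build the model using Jax expressions
model.events = []                # optional: drop events to compare with JaxSolver

# Route expression evaluation through the IREE/MLIR backend
solver = pybamm.IDAKLUSolver(
    rtol=1e-6,
    atol=1e-6,
    options={"jax_evaluator": "iree"},
)

sim = pybamm.Simulation(model, solver=solver)
# The model is demoted to 32-bit precision in-place; a warning is displayed
solution = sim.solve([0, 3600])
```

For larger models (e.g. DFN), the tolerances above would need to be relaxed, as noted in the description.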
Comparative performance on the above SPM problem on an Apple M2 MacBook Pro (with `events=[]` to allow comparison to the JaxSolver):
Substituting a DFN model (and relaxing `atol` to `1e-1`), the times become:

There is a noticeable performance deficit for the IDAKLU-MLIR solver compared to Casadi, due to 1) the initial compilation of MLIR to bytecode, 2) demotion strategies, and 3) memory transfers and casting between types in the solver. We anticipate improvements in the second and third points with native 64-bit IREE support. Since our IREE approach compiles only the model expressions (not the solver), compilation time quickly out-performs the JaxSolver with increasing model complexity / number of time steps (while also taking full advantage of the capabilities already provided by the IDAKLU solver, such as events). The IREE/MLIR approach offers a pathway to compiling expressions for a wide variety of backends, including Metal and CUDA, although additional code adjustments (principally host/device transfers) will be required before those can be supported.
Resolves #3826
Type of change
Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.
Key checklist:
- `$ pre-commit run` (or `$ nox -s pre-commit`) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
- `$ python run-tests.py --all` (or `$ nox -s tests`)
- `$ python run-tests.py --doctest` (or `$ nox -s doctests`)

You can run integration tests, unit tests, and doctests together at once, using `$ python run-tests.py --quick` (or `$ nox -s quick`).

Further checks: