pybamm-team / PyBaMM

Fast and flexible physics-based battery models in Python
https://www.pybamm.org/
BSD 3-Clause "New" or "Revised" License

Add support for MLIR-based expression evaluation #4199

Open jsbrittain opened 2 weeks ago

jsbrittain commented 2 weeks ago

Description

Add a new expression evaluation backend to the IDAKLU solver. MLIR expression evaluation is now supported by lowering PyBaMM's Jax-based expressions into MLIR, which are then compiled and executed as part of the IDAKLU solver using IREE.

To enable the IREE/MLIR backend, set the (new) PYBAMM_IDAKLU_EXPR_IREE compiler flag ON via an environment variable and install PyBaMM using the developer method (by default PYBAMM_IDAKLU_EXPR_IREE is turned OFF):

export PYBAMM_IDAKLU_EXPR_IREE=ON
nox -e pybamm-requires && nox -e dev

Expression evaluation in IDAKLU is enabled by constructing the model using Jax expressions (model.convert_to_format="jax") and setting the solver backend (jax_evaluator="iree"). Example:

import pybamm
import numpy as np

model = pybamm.lithium_ion.SPM()
model.convert_to_format = "jax"
geometry = model.default_geometry
param = model.default_parameter_values
param.process_model(model)
param.process_geometry(geometry)  # process the same geometry object that is passed to the mesh
mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

solver = pybamm.IDAKLUSolver(
    root_method="hybr",  # change from default ("casadi")
    options={"jax_evaluator": "iree"}
)
solution = solver.solve(model, np.linspace(0, 3600, 2500))

print(solution["Voltage [V]"].entries[:100])
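The way model.convert_to_format and the jax_evaluator option interact, as described in this thread, can be summarised in a small sketch. The function name evaluation_path and the returned labels are hypothetical and only illustrate the routing; they are not part of PyBaMM's API. The sketch assumes that any jax_evaluator value other than "iree" falls through to the existing python-idaklu interface, and that "jax" is the default value of that option.

```python
def evaluation_path(convert_to_format: str, jax_evaluator: str = "jax") -> str:
    """Hypothetical helper: name the expression-evaluation path IDAKLU would take.

    The labels returned here are illustrative only.
    """
    if convert_to_format == "casadi":
        # Casadi expressions are evaluated by the existing Casadi path.
        return "casadi"
    if convert_to_format == "jax":
        if jax_evaluator == "iree":
            # Jax expressions are lowered to MLIR and executed through IREE.
            return "iree-mlir"
        # Assumption: without the redirect, the legacy python-idaklu interface runs.
        return "python-idaklu"
    raise ValueError(f"unsupported convert_to_format: {convert_to_format!r}")


print(evaluation_path("jax", "iree"))
print(evaluation_path("jax"))
```

Keeping the redirect behind an explicit option means existing Jax-format models keep their current behaviour unless the user opts in to IREE.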

Note that IREE currently supports only single-precision floating-point operations, so the model must be demoted from 64-bit to 32-bit precision before the solver can run. The solver handles this itself, but the demotion is performed in-place on the PyBaMM battery model (a warning is displayed when this happens). Operating at lower precision requires tolerances to be relaxed for convergence on larger models (e.g. DFN). It also introduces memory transfers and type casting in the solver, which currently cause slow-downs (at least until 64-bit computation is natively supported).
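To give a feel for why tolerances need relaxing at single precision, the following minimal sketch rounds Python's 64-bit floats to the nearest 32-bit value using only the standard library. It is not the demotion code used in the solver; the helper to_float32 and the sample value 3.7 are illustrative assumptions.

```python
import struct


def to_float32(x: float) -> float:
    """Round a 64-bit Python float to the nearest 32-bit float value."""
    return struct.unpack("f", struct.pack("f", x))[0]


FLOAT32_EPS = 2.0 ** -23  # gap between 1.0 and the next representable float32
FLOAT64_EPS = 2.0 ** -52  # gap between 1.0 and the next representable float64

# A perturbation that float64 resolves but float32 cannot.
x = 1.0 + 1e-9
perturbation_lost = (x != 1.0) and (to_float32(x) == 1.0)

# Relative rounding error introduced by demoting an arbitrary sample value.
value = 3.7  # illustrative number, not taken from any model
rel_err = abs(to_float32(value) - value) / value

print("perturbation lost after demotion:", perturbation_lost)
print("relative rounding error:", rel_err)
print("float32 eps / float64 eps:", FLOAT32_EPS / FLOAT64_EPS)
```

Any solver tolerance set below roughly FLOAT32_EPS relative to the solution magnitude asks for more accuracy than 32-bit arithmetic can represent, which is consistent with the need to relax tolerances noted above.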

Comparative performance on the above SPM problem on an Apple M2 MacBook Pro (with events=[] to allow comparison to the JaxSolver):

Substituting a DFN model (and relaxing the absolute tolerance to atol = 1e-1) the times become:

There is a noticeable performance deficit for the IDAKLU-MLIR solver compared to Casadi, due to 1) initial compilation of MLIR to bytecode, 2) demotion strategies, and 3) memory transfers and casting between types in the solver. We anticipate improvements in the second and third points with native 64-bit IREE support. Because our IREE approach compiles the model expressions (not the solver), its compilation times quickly out-perform the JaxSolver as model complexity and the number of time steps increase, while still taking full advantage of the capabilities already provided by the IDAKLU solver, such as events. The IREE/MLIR approach also offers a pathway to compiling expressions for a wide variety of backends, including Metal and CUDA, although additional code adjustments (principally host/device transfers) will be required before those can be supported.
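When comparing backends that pay a one-off compilation cost, it helps to time the first call separately from later calls. The helper below is a generic sketch of that idea and is not the procedure used to produce the timings in this PR. The names time_first_and_repeat and fake_solve are hypothetical, and fake_solve only mimics a compile-once workload with sleeps.

```python
import time
from statistics import mean


def time_first_and_repeat(fn, repeats=3):
    """Return (first_call_seconds, mean_repeat_seconds) for a zero-argument callable."""
    start = time.perf_counter()
    fn()
    first = time.perf_counter() - start

    later = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        later.append(time.perf_counter() - start)
    return first, mean(later)


_state = {"compiled": False}


def fake_solve():
    """Stand-in workload: slow on the first call, fast afterwards."""
    if not _state["compiled"]:
        time.sleep(0.05)  # pretend one-off compilation
        _state["compiled"] = True
    time.sleep(0.001)  # pretend steady-state solve


first, steady = time_first_and_repeat(fake_solve)
print(f"first call: {first:.4f} s, later calls: {steady:.4f} s")
```

With a real model, fake_solve could be replaced by something like lambda: solver.solve(model, np.linspace(0, 3600, 2500)), using the solver and model from the example above.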

Resolves #3826

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.

Key checklist:

You can run integration tests, unit tests, and doctests together at once, using $ python run-tests.py --quick (or $ nox -s quick).

Further checks:

jsbrittain commented 2 weeks ago

Note: Codacy seems to be struggling with the template / inheritance structures in C++, hence the inclusion of additional // cppcheck-suppress comments in the IDAKLU solver.

codecov[bot] commented 2 weeks ago

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.57%. Comparing base (e22d10c) to head (efb0800). Report is 3 commits behind head on develop.

Additional details and impacted files

```diff
@@             Coverage Diff             @@
##           develop    #4199      +/-   ##
===========================================
+ Coverage    99.55%   99.57%   +0.01%
===========================================
  Files          288      288
  Lines        21856    22048     +192
===========================================
+ Hits         21759    21954     +195
+ Misses          97       94       -3
```


jsbrittain commented 1 week ago

> ccing @cringeyburger here because this PR will be adding and modifying a lot of compiled code, so we shall need to make a few adjustments in the migration to scikit-build-core as needed. Though, as long as the wheel builds pass (@jsbrittain, could you please trigger them on your fork?), it should be fine.

@agriyakhetarpal Wheels build fine: https://github.com/jsbrittain/PyBaMM/actions/runs/9665563143

jsbrittain commented 1 week ago

> thanks @jsbrittain, this looks excellent. I've made a few suggestions below, see what you think. Also, is jax_evaluator="iree" needed for the options? Is there a case where you want to convert to jax but not use the jax_evaluator?

@martinjrobins yes, there is actually an existing python-idaklu interface that will run if we don't redirect using the (new) jax_evaluator option. I think it's a legacy item (idaklu/python.cpp) (it can be quite slow, even on these toy examples).

martinjrobins commented 1 week ago

that reminds me, we should get rid of python-idaklu, I don't think it serves a useful purpose anymore. I'll add an issue