pybamm-team / PyBaMM

Fast and flexible physics-based battery models in Python
https://www.pybamm.org/
BSD 3-Clause "New" or "Revised" License

refactor multiprocessing and multiple inputs #4087

Status: open. martinjrobins opened this issue 4 months ago.

martinjrobins commented 4 months ago

Description

Refactor the current multiprocessing implementation to push this responsibility onto the solver.

Related issues #3987, #3713, #2645, #2644, #3910

Motivation

At the moment, users can request multiple parallel solves by passing a list of inputs:

solver.solve(t_eval, inputs = [ {'a': 1}, {'a': 2} ])

For all solvers apart from the JAX solver, this uses the Python multiprocessing library to run these solves in parallel. However, this causes issues on Windows (#3987), and it could probably be done much more efficiently at the solver level than up in Python.

Possible Implementation

I propose that if a list of n inputs is detected, the model equations are duplicated n times before being passed to the solver. All the solvers we use have built-in parallel functionality that can be used to run these copies in parallel, and once we have a decent GPU solver this could also be done on the GPU. After the solve finishes, the larger state vector would be split apart and passed to multiple Solution objects.
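
A minimal sketch of the duplicate, solve, and split idea, assuming SciPy's solve_ivp as a stand-in for the PyBaMM solvers and a hypothetical one-state model dy/dt = -a*y. The names rhs_single and solve_stacked are illustrative only and are not PyBaMM API.

```python
import numpy as np
from scipy.integrate import solve_ivp


def rhs_single(t, y, inputs):
    """Hypothetical one-state model: dy/dt = -a * y, with a taken from inputs."""
    return -inputs["a"] * y


def solve_stacked(rhs, y0, t_eval, inputs_list):
    """Solve len(inputs_list) copies of one model as a single, larger ODE system."""
    y0 = np.asarray(y0, dtype=float)
    n_states, n_inputs = y0.size, len(inputs_list)
    # Duplicate the initial state once per input: [y0, y0, ..., y0]
    y0_big = np.tile(y0, n_inputs)

    def rhs_big(t, y):
        # View the big state vector as one row per input and evaluate each copy
        blocks = y.reshape(n_inputs, n_states)
        return np.concatenate(
            [rhs(t, blocks[i], inputs_list[i]) for i in range(n_inputs)]
        )

    sol = solve_ivp(
        rhs_big, (t_eval[0], t_eval[-1]), y0_big, t_eval=t_eval, rtol=1e-8, atol=1e-10
    )
    if not sol.success:
        raise RuntimeError(sol.message)
    # Split the stacked solution (n_inputs * n_states rows) back into one array per input
    return [sol.y[i * n_states:(i + 1) * n_states] for i in range(n_inputs)]


t_eval = np.linspace(0.0, 1.0, 11)
solutions = solve_stacked(rhs_single, [1.0], t_eval, [{"a": 1.0}, {"a": 2.0}])
```

Because all copies share one adaptive integrator here, the step size is chosen for the whole stacked system, so a single fast-changing input can slow down the others. The list comprehension in rhs_big is serial in this sketch; a compiled solver could evaluate the blocks in parallel instead.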

Additional context

The following snippet shows how multiprocessing is currently implemented:

            ninputs = len(model_inputs_list)
            if ninputs == 1:
                new_solution = self._integrate(
                    model,
                    t_eval[start_index:end_index],
                    model_inputs_list[0],
                )
                new_solutions = [new_solution]
            else:
                if model.convert_to_format == "jax":
                    # Jax can parallelize over the inputs efficiently
                    new_solutions = self._integrate(
                        model,
                        t_eval[start_index:end_index],
                        model_inputs_list,
                    )
                else:
                    with mp.get_context(self._mp_context).Pool(processes=nproc) as p:
                        new_solutions = p.starmap(
                            self._integrate,
                            zip(
                                [model] * ninputs,
                                [t_eval[start_index:end_index]] * ninputs,
                                model_inputs_list,
                            ),
                        )
                        p.close()
                        p.join()
martinjrobins commented 4 months ago

The main drawback of this approach is that it would be very inefficient if you wanted to use dense matrices for the Jacobian. I'm not sure that would be an issue, though, as we use sparse matrices by default.
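
To make this trade-off concrete: duplicating the equations gives a block-diagonal Jacobian. The sketch below assumes SciPy sparse matrices and uses a tridiagonal matrix as a hypothetical stand-in for a single model's Jacobian; the helper name duplicated_jacobian is illustrative only.

```python
import numpy as np
import scipy.sparse as sp


def duplicated_jacobian(jac, n_inputs):
    """Jacobian of n_inputs independent copies of one model: a block-diagonal matrix."""
    return sp.block_diag([jac] * n_inputs, format="csr")


n_states, n_inputs = 100, 8

# Hypothetical single-model Jacobian: tridiagonal, like a 1D diffusion discretisation
jac = sp.diags(
    [np.ones(n_states - 1), -2.0 * np.ones(n_states), np.ones(n_states - 1)],
    [-1, 0, 1],
    format="csr",
)

big = duplicated_jacobian(jac, n_inputs)
print(big.shape)                    # (800, 800)
print(big.nnz)                      # 2384 stored non-zeros (8 copies x 298)
print(big.shape[0] * big.shape[1])  # 640000 entries if stored densely
```

Sparse storage grows linearly with the number of inputs n, while a dense Jacobian of the stacked system grows as n squared (and a dense LU factorisation faster still), which is why dense matrices would scale poorly with this approach.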

martinjrobins commented 4 months ago

Another downside arises if a user wants to change the number of inputs frequently: new equations would have to be set up every time, which would be quite inefficient.

martinjrobins commented 4 months ago

Any comments on this approach are welcome. The above is my current thinking, but I'm happy to consider alternatives before I dive into making changes: @BradyPlanden @rtimms @jsbrittain @valentinsulzer.

BradyPlanden commented 4 months ago

Sounds like a good plan. Thanks @martinjrobins.

One area that might require some thought is ensuring that threads are not over-provisioned, i.e. when users run a higher-level Python multiprocessing job that interacts with the solver's own parallelisation. A kwarg that lets the user set the number of threads used by the solver might do the trick.
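
As an illustration of how such a kwarg might choose a sensible default, here is a hypothetical helper (threads_per_worker is not a PyBaMM function) that divides the available cores between the outer worker processes so that the total thread count stays within the core count.

```python
import os
from typing import Optional


def threads_per_worker(n_workers: int, total_cores: Optional[int] = None) -> int:
    """Hypothetical default for a solver thread-count kwarg.

    Splits the available cores evenly between n_workers outer processes,
    giving each at least one thread, so n_workers * threads <= total_cores
    whenever n_workers <= total_cores.
    """
    if n_workers < 1:
        raise ValueError("n_workers must be at least 1")
    if total_cores is None:
        # os.cpu_count() can return None on some platforms
        total_cores = os.cpu_count() or 1
    return max(1, total_cores // n_workers)
```

For example, 4 outer processes on a 16-core machine would give each solver 4 threads, while 16 processes on 8 cores would fall back to 1 thread each.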

agriyakhetarpal commented 4 months ago

I would also recommend the same thing. If we were to add support for Emscripten/Pyodide sometime soon, we would need to disable threading, or put guards around the threading and multiprocessing imports, so that import pybamm and the rest of the functionality work under a WASM runtime (where a browser cannot start a new thread or run subprocesses).
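
One way such guards could look, sketched under the assumption that WebAssembly builds of CPython report sys.platform as "emscripten" (Pyodide) or "wasi": import multiprocessing lazily and fall back to a serial loop. multiprocessing_available and map_inputs are hypothetical names, not existing PyBaMM functions.

```python
import sys
from typing import Callable, Iterable, List, Optional, TypeVar

T = TypeVar("T")
R = TypeVar("R")

# sys.platform values used by WebAssembly builds of CPython
_WASM_PLATFORMS = {"emscripten", "wasi"}


def multiprocessing_available(platform: Optional[str] = None) -> bool:
    """Return False on WASM runtimes, where new processes cannot be started."""
    if platform is None:
        platform = sys.platform
    return platform not in _WASM_PLATFORMS


def map_inputs(
    func: Callable[[T], R],
    inputs: Iterable[T],
    nproc: int = 1,
    platform: Optional[str] = None,
) -> List[R]:
    """Apply func to each input, in parallel only when it is safe and useful."""
    items = list(inputs)
    if nproc <= 1 or len(items) <= 1 or not multiprocessing_available(platform):
        # Serial fallback: also the only option under Pyodide/Emscripten
        return [func(item) for item in items]
    # Lazy import, so importing this module never touches multiprocessing
    import multiprocessing as mp

    # func must be picklable (e.g. a module-level function) to be sent to workers
    with mp.get_context("spawn").Pool(processes=nproc) as pool:
        return pool.map(func, items)
```

Keeping the import inside the function means the module itself imports cleanly on any platform; only the parallel branch depends on multiprocessing.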

martinjrobins commented 4 months ago

The solvers would use OpenMP to run in parallel, so this would be controlled by setting the environment variable OMP_NUM_THREADS. I'll make sure to document this for users who are also using multiprocessing.
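
A small sketch of how users combining multiprocessing with OpenMP-parallel solvers might cap the threads per process. It assumes that the OpenMP runtime reads OMP_NUM_THREADS when it starts up, and that child processes started with the "spawn" method inherit the parent's os.environ at start time. omp_threads is a hypothetical helper, not part of PyBaMM.

```python
import os
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def omp_threads(n_threads: int) -> Iterator[None]:
    """Temporarily set OMP_NUM_THREADS, restoring the previous value on exit.

    Processes started inside the block inherit the setting, e.g.
        with omp_threads(2), mp.get_context("spawn").Pool(4) as pool:
            ...
    starts 4 worker processes, each with OMP_NUM_THREADS=2 in its environment.
    """
    key = "OMP_NUM_THREADS"
    previous = os.environ.get(key)
    os.environ[key] = str(n_threads)
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = previous
```

Whether changing the variable affects a process that has already initialised its OpenMP runtime depends on the runtime, so the sketch sets it before the worker processes are created rather than inside them.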

agriyakhetarpal commented 2 weeks ago

Reopened since #4449 noted that it was only a partial fix (I guess GitHub treats "partially fixes X" the same as "fixes X").