pybamm-team / liionpack

A battery pack simulation tool that uses the PyBaMM framework
https://liionpack.readthedocs.io/en/latest/
MIT License

Use Dask to run the cell simulations in parallel #30

Closed: wigging closed this issue 2 years ago

wigging commented 3 years ago

This isn't really an issue but more of a feature request or enhancement. Would it be possible to use Dask to run the cell simulations in parallel? I created a basic example (see below) that runs the SPMe model for several discharge currents in parallel using Dask. To compare the elapsed time with and without Dask, comment out the appropriate section in main(). The table below gives the elapsed times when running on an 8-core CPU.

Dask is made for massive parallelization, and it's fairly easy to set up for CPU and GPU clusters. If it can be used with liionpack, it could provide a huge performance boost for large pack simulations. I haven't tried CasADi's parallel features beyond running on a single CPU, but I don't think they will be as easy to scale as Dask.

Run        Elapsed time [s]
No Dask    8.57
Dask       3.83

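The timings in the table translate into a speedup and a parallel efficiency as follows. The efficiency line assumes all 8 cores were available to the Dask workers; the comment does not say how the local cluster was configured.

```python
# Elapsed times from the table above (8 SPMe discharges on an 8-core CPU)
no_dask_s = 8.57
dask_s = 3.83
n_cores = 8  # assumption: all cores were available to the Dask workers

speedup = no_dask_s / dask_s     # how many times faster the Dask run was
efficiency = speedup / n_cores   # fraction of ideal linear scaling achieved

print(f"speedup = {speedup:.2f}x, parallel efficiency = {efficiency:.0%}")
# prints: speedup = 2.24x, parallel efficiency = 28%
```

In the example script the Client is created in `__main__` before `tic` is taken, so the Dask timing excludes cluster start-up.
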
import matplotlib.pyplot as plt
import pybamm
import time
from dask.distributed import Client

def generate_plots(discharge, t, capacity, current, voltage):

    def styleplot(ax):
        ax.legend(loc='best')
        ax.grid(color='0.9')
        ax.set_frame_on(False)
        ax.tick_params(color='0.9')

    _, ax = plt.subplots(tight_layout=True)
    for i in range(len(discharge)):
        ax.plot(t[i], current[i], label=f'{discharge[i]} A')
    ax.set_xlabel('Time [s]')
    ax.set_ylabel('Current [A]')
    styleplot(ax)

    _, ax = plt.subplots(tight_layout=True)
    for i in range(len(discharge)):
        ax.plot(t[i], voltage[i], label=f'{discharge[i]} A')
    ax.set_xlabel('Time [s]')
    ax.set_ylabel('Terminal voltage [V]')
    styleplot(ax)

    _, ax = plt.subplots(tight_layout=True)
    for i in range(len(discharge)):
        ax.plot(capacity[i], voltage[i], label=f'{discharge[i]} A')
    ax.set_xlabel('Discharge capacity [Ah]')
    ax.set_ylabel('Terminal voltage [V]')
    styleplot(ax)

    plt.show()

def run_simulation(dis, t_eval):

    model = pybamm.lithium_ion.SPMe()

    param = model.default_parameter_values
    param['Current function [A]'] = '[input]'

    sim = pybamm.Simulation(model, parameter_values=param)
    sim.solve(t_eval, inputs={'Current function [A]': dis})

    return sim.solution

def main(client):
    tic = time.perf_counter()

    discharge = [4, 3.5, 3, 2.5, 2, 1.8, 1.5, 1]  # discharge currents [A]
    t_eval = [0, 4000]                            # evaluation time [s]

    # No Dask
    # ------------------------------------------------------------------------

    # label = 'no Dask'

    # sols = []
    # for dis in discharge:
    #     sol = run_simulation(dis, t_eval)
    #     sols.append(sol)

    # Dask
    # ------------------------------------------------------------------------

    label = 'Dask'

    lazy_sols = client.map(run_simulation, discharge, t_eval=t_eval)
    sols = client.gather(lazy_sols)

    # ------------------------------------------------------------------------

    t = []
    capacity = []
    current = []
    voltage = []

    for sol in sols:
        t.append(sol['Time [s]'].entries)
        capacity.append(sol['Discharge capacity [A.h]'].entries)
        current.append(sol['Current [A]'].entries)
        voltage.append(sol["Terminal voltage [V]"].entries)

    toc = time.perf_counter()
    print(f'Elapsed time ({label}) = {toc - tic:.2f} s')

    generate_plots(discharge, t, capacity, current, voltage)

if __name__ == '__main__':
    client = Client()
    print(client)
    main(client)

TomTranter commented 3 years ago

That's nice! I've heard good things about Dask but haven't really played with it myself. I think the scaling issue we have may be due to the data transfer between the threads, which happens on every time step. I implemented a shared-memory parallel process pool on Linux and it sped things up a lot, because all the threads were just accessing different indices of the same shared array to look up the state vector and write back the results. At the time I used a package that only worked on Linux, but maybe Dask can help us do something similar. Adaptive time stepping will also help us out a little. I'll work on this tomorrow.
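A minimal sketch of the shared-memory pattern described above, using only the standard library's `multiprocessing` module (not the Linux-only package mentioned, which is not named in the comment). The state array is allocated once in shared memory and handed to each worker through the pool initializer; every task then receives only a `(start, stop, dt)` tuple and updates its own slice in place, so no state vectors are pickled between processes. `STATES_PER_CELL`, `_step_chunk` and `run` are hypothetical names, the dynamics are a placeholder rather than a cell model, and the explicit `"fork"` start method assumes Linux or another POSIX system.

```python
import multiprocessing as mp

STATES_PER_CELL = 2  # hypothetical number of state variables per cell

_states = None  # each worker's handle to the shared array (set by _init_worker)


def _init_worker(shared_states):
    """Pool initializer: keep a reference to the shared array in a global."""
    global _states
    _states = shared_states


def _step_chunk(task):
    """Advance cells [start, stop) in place; only the task tuple is pickled."""
    start, stop, dt = task
    for i in range(start * STATES_PER_CELL, stop * STATES_PER_CELL):
        # Placeholder dynamics (explicit Euler step of dx/dt = -x), not a cell model
        _states[i] -= dt * _states[i]
    return stop - start  # number of cells this task updated


def run(n_cells=8, n_workers=4, n_steps=3, dt=0.5):
    ctx = mp.get_context("fork")  # assumes Linux or another POSIX system
    # lock=False gives a raw shared ctypes array; tasks write disjoint index ranges
    states = ctx.Array("d", [1.0] * (n_cells * STATES_PER_CELL), lock=False)
    size = -(-n_cells // n_workers)  # ceiling division
    tasks = [(s, min(s + size, n_cells), dt) for s in range(0, n_cells, size)]
    with ctx.Pool(n_workers, initializer=_init_worker, initargs=(states,)) as pool:
        for _ in range(n_steps):
            pool.map(_step_chunk, tasks)  # returns once every chunk is done
    return list(states)


if __name__ == "__main__":
    print(run())
```

The pool is created once and reused for every step, so the per-step cost is just dispatching a few small tuples.
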

srikanthallu commented 3 years ago

@TomTranter, could you elaborate some more on the data dependencies between threads? Is it because we are doing a coupled solve with the thermal model? Also, what library did you use for the shared-memory parallel process pool? Adaptive time stepping is important and will definitely help us accelerate.

wigging commented 3 years ago

@TomTranter Let me know how I can help with this. I don't want you to have all the fun 😄. I can work on a separate development branch and push it up to the repo.

Also, Dask has a lot of dashboards that visualize the workload, so you can see how the system is being used. For the example I posted, the Dask workers can be viewed in real time in the browser using the URL from client.dashboard_link.

valentinsulzer commented 3 years ago

Ideally, it would be good to create the model only once:

model = pybamm.lithium_ion.SPMe()
param = model.default_parameter_values
param['Current function [A]'] = '[input]'
sim = pybamm.Simulation(model, parameter_values=param)

and parallelize only the solve step

sim.solve(t_eval, inputs={'Current function [A]': dis})

This might help with memory usage.
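The build-once, solve-many suggestion maps naturally onto a process pool initializer. Below is a minimal sketch using the standard library's `multiprocessing.Pool` rather than Dask: each worker builds its "simulation" once when it starts, and each task sends only the input current. To keep it runnable without PyBaMM, `_build_simulation` is a hypothetical stand-in for the setup code above and the returned callable stands in for `sim.solve(...)`; the build counter exists only to show that each worker builds once. The explicit `"fork"` start method assumes a POSIX system.

```python
import multiprocessing as mp
import os

_simulation = None  # built once per worker process by _init_worker
_builds = 0         # how many times this worker process has built it


def _build_simulation():
    """Hypothetical stand-in for the expensive setup (model, parameters, Simulation)."""
    global _builds
    _builds += 1
    # A real version would build the PyBaMM simulation with 'Current function [A]'
    # set to '[input]'; this placeholder just returns a cheap callable.
    return lambda current: 2.0 * current


def _init_worker():
    global _simulation
    _simulation = _build_simulation()


def _solve(current):
    """Only the input value travels to the worker; the setup is reused."""
    return os.getpid(), _builds, _simulation(current)


def run(currents, n_workers=2):
    ctx = mp.get_context("fork")  # assumes Linux or another POSIX system
    with ctx.Pool(n_workers, initializer=_init_worker) as pool:
        return pool.map(_solve, currents)


if __name__ == "__main__":
    for pid, builds, value in run([4, 3.5, 3, 2.5, 2, 1.8, 1.5, 1]):
        print(f"worker {pid}: builds={builds}, result={value}")
```

The built object never leaves its worker, so it never has to be serialized; only the input floats and the results cross process boundaries.
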

wigging commented 3 years ago

@tinosulzer I agree. For the naive example above, I just put something together to see how Dask would work. I will run another example with changes to the model execution and report the results.

srikanthallu commented 3 years ago

@tinosulzer, so the idea is to set up a task pool and spawn independent solves. For this, we need a better understanding of the memory requirements for a coupled solve and of how data is shared across tasks.

valentinsulzer commented 3 years ago

This would be a good first issue to tackle for understanding the memory requirements: https://github.com/pybamm-team/PyBaMM/issues/1442

valentinsulzer commented 3 years ago

The way initial conditions are handled, especially when there are input parameters in the initial conditions, is also not perfect https://github.com/pybamm-team/PyBaMM/blob/6ccce9f817db8e5df69e6a3de90587b62e20180a/pybamm/solvers/base_solver.py#L1261-L1294

This might cause some difficulties with creating a single model and solving multiple times. That part of the code could do with a fresh pair of eyes tbh

wigging commented 3 years ago

I'm not sure if this is related to @tinosulzer's previous comment, but I get several Dask errors if I pass the model object to the mapped function:

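import pybamm
from dask.distributed import Client

def run_simulation(dis, t_eval, model):
    param = model.default_parameter_values
    param['Current function [A]'] = '[input]'

    sim = pybamm.Simulation(model, parameter_values=param)
    sim.solve(t_eval, inputs={'Current function [A]': dis})

    return sim.solution

def main(client):
    discharge = [4, 3.5, 3, 2.5, 2, 1.8, 1.5, 1]  # discharge currents [A]
    t_eval = [0, 4000]                            # evaluation time [s]

    # The model is created once here and passed to every mapped task, so Dask
    # has to serialize it and ship it to the workers; this is the version that
    # produces the Dask errors
    model = pybamm.lithium_ion.SPMe()
    lazy_sols = client.map(run_simulation, discharge, t_eval=t_eval, model=model)
    sols = client.gather(lazy_sols)
    return sols

if __name__ == '__main__':
    client = Client()
    main(client)
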
TomTranter commented 3 years ago

@srikanthallu the data transfer is from the main Python loop to the threads, and back again after integrating the CasADi functions over a time step. Then we get a new OCV and Ri and use them in the circuit to balance the local currents and set the boundary conditions for the next integration. So you have to pass the state vector of each cell back and forth, as well as the output variables. I think this is where things are slowing down, but I can't be sure. The thermal problem is actually handled independently right now, but you are right that you would need to include it in the global calculation steps (alongside the current balancing) if you had inter-cell heat transfer.

@wigging definitely happy for you to work on this too. I've just been running through a few Dask tutorials this morning to understand a bit more about what it's doing. Probably the way forward is to generate the CasADi functions in single-cell format instead of mapped format (i.e. just don't call map on them when passing them back) and then get Dask to handle their parallelization, if that works.
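To make the per-step data exchange concrete, here is a simplified sketch of how such a loop could look for cells wired in parallel. Each cell is reduced to an OCV source in series with an internal resistance Ri; a global step balances the currents so every cell sees the same terminal voltage, and those currents become the boundary conditions for the next per-cell integration. The linear OCV curve, the coulomb-counting `step_cell` and the 5 Ah capacity are hypothetical placeholders, and liionpack's actual circuit solve (which also has to handle series connections, such as the 16p2s pack mentioned later in the thread) is not reproduced here.

```python
def balance_currents(ocv, ri, i_total):
    """Parallel cells: solve sum_k (ocv_k - v) / ri_k = i_total for the common voltage v."""
    g = [1.0 / r for r in ri]  # cell conductances [S]
    v = (sum(e * gk for e, gk in zip(ocv, g)) - i_total) / sum(g)
    currents = [(e - v) * gk for e, gk in zip(ocv, g)]  # positive = discharge
    return v, currents


def ocv_of(soc):
    """Hypothetical linear OCV curve [V]; in the real loop it comes from each cell's solution."""
    return 3.0 + 1.2 * soc


def step_cell(soc, current, dt, capacity_ah):
    """Hypothetical per-cell 'integrator': coulomb counting over one step of dt seconds."""
    return soc - current * dt / (3600.0 * capacity_ah)


def simulate(socs, ri, i_total, dt, n_steps, capacity_ah=5.0):
    history = []
    for _ in range(n_steps):
        ocv = [ocv_of(s) for s in socs]                   # outputs collected from each cell
        v, currents = balance_currents(ocv, ri, i_total)  # global circuit step
        socs = [step_cell(s, i, dt, capacity_ah)          # per-cell step: the parallel part
                for s, i in zip(socs, currents)]
        history.append((v, currents))
    return socs, history


if __name__ == "__main__":
    socs, history = simulate([0.9, 0.8], [0.05, 0.05], i_total=2.0, dt=10.0, n_steps=3)
    v, currents = history[0]
    print(f"v = {v:.3f} V, cell currents = {[round(i, 3) for i in currents]} A")
```

Only the per-cell step is independent across cells; the balancing step needs every cell's OCV and Ri, which is why states or outputs have to be gathered on every time step.
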

TomTranter commented 3 years ago

For reference, this was what I did before: https://github.com/pybamm-team/PyBaMM/issues/849

wigging commented 3 years ago

I created a dask branch for the Dask solver. See solve_dask() in solver_utils.py.

TomTranter commented 3 years ago

OK cool, I have made a serial version of mapped_step, which should be an easier starting point for making a Dask mapped step. Hopefully we won't need two versions of the solve method. I'll push what I've done to that branch.

TomTranter commented 3 years ago

Hey, so I pushed some changes. I think you basically want to put as much of the code in the for loop in _serial_step as possible into a Dask process.

wigging commented 3 years ago

@TomTranter Dask is complaining that some of the variable types are DM. It looks like these are types returned from Casadi. Can you edit _serial_step_dask and solve_dask so there are no Casadi objects? I don't know how to support non-standard Python objects with Dask. I'm still learning how to use it.

wigging commented 3 years ago

Apparently, objects must be serializable to be used with Dask. This may be causing the issues I'm having with passing CasADi and PyBaMM objects to the mapped Dask function. http://distributed.dask.org/en/latest/serialization.html
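Because every argument shipped to a worker has to be serialized, one way to narrow down which arguments are the problem is to try pickling them one at a time before calling `client.map`. This is a conservative, standard-library-only check: Dask's default serialization also falls back to cloudpickle, which accepts some objects (such as lambdas) that plain `pickle` rejects, so a flagged argument is not guaranteed to fail in Dask. `find_unpicklable` is a hypothetical helper, and the `threading.Lock` merely stands in for an object that won't serialize; whether a given CasADi or PyBaMM object pickles depends on the library version.

```python
import pickle
import threading


def find_unpicklable(kwargs):
    """Return the names of arguments that plain pickle cannot serialize."""
    bad = []
    for name, value in kwargs.items():
        try:
            pickle.dumps(value)
        except Exception:  # TypeError, PicklingError, AttributeError, ...
            bad.append(name)
    return bad


if __name__ == "__main__":
    args = {
        "t_eval": [0, 4000],          # plain Python data pickles fine
        "state": [0.1, 0.2, 0.3],
        "lock": threading.Lock(),     # stands in for an object that won't serialize
    }
    print(find_unpicklable(args))  # ['lock']
```
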

TomTranter commented 3 years ago

I think the integrator will work if you pass NumPy arrays... but I'm not sure

TomTranter commented 3 years ago

@wigging as you are working on the dask branch right now, I won't pull down the changes from main, as that would result in conflicts, but you might want to do this yourself before pushing much more.

wigging commented 3 years ago

@TomTranter I rebased the dask branch with main so everything is up-to-date with main. If you need to work on the dask branch make sure you pull (or force pull) the latest version to your local dask branch.

TomTranter commented 3 years ago

Why rebase instead of just merging the changes back in? I think this messed something up for my local version, but I'm not sure to be honest.

wigging commented 3 years ago

Works fine for me. And by "fine" I mean I'm still getting the same Dask errors as before the rebase. I had to update the inputs for solve_dask so maybe you didn't pull down those changes too.

TomTranter commented 3 years ago

I'd prefer it if you merged instead of rebased please. I'm not a git expert and it seems like the safer option. https://www.atlassian.com/git/tutorials/merging-vs-rebasing

TomTranter commented 3 years ago

I just had to delete my local copy and clone again to sort out my copy of that branch... I don't know what exact sequence of events led to that, but I never usually have to do this when I just pull changes down.

TomTranter commented 3 years ago

I know it's probably hard to see the wood for the trees right now, with all the merging into main and committing directly to it, which I'll stop doing.

wigging commented 3 years ago

I deleted the dask branch. I'll upload it again when it's more presentable.

wigging commented 3 years ago

Alright, I re-uploaded the dask branch. It's in a more complete state now. See the _step_dask and solve_dask functions. I don't get any errors but it is not executing in parallel. I tried a 16p2s simulation for one charge/discharge cycle and it said it would take more than one hour to complete. You can view the Dask dashboard and see that the workers are not running in parallel. I have no idea right now why it's not running in parallel.

wigging commented 3 years ago

I prefer to rebase the branch with main to keep the git timeline clean. If you use merge, there's going to be a lot of duplicate merge commits in the log when (or if) this branch is merged with main. Before you make changes on your local dask branch, pull any remote changes with git pull -f origin dask. Then push your changes back up with git push origin dask or git push -f origin dask. I'll try to keep things up-to-date with main and do all the rebases if needed.

TomTranter commented 3 years ago

I'm getting issues again, maybe because I use Sourcetree, but I pulled down some changes in the dask branch and it's saying I have merge conflicts.

I really really would prefer it if you didn't rebase

TomTranter commented 3 years ago

I just did a hard reset on my local machine to fix the issue: git reset --hard origin/dask

TomTranter commented 3 years ago

@wigging I made a little progress on this. I added a self-contained script called dask debug with only the bare minimum of integrator functionality being mapped. If you look at the dashboard, there is a lot of waiting with nothing happening, which I guess is data transfer between steps. Maybe turning all the states into one big Dask array and indexing into it would speed things up, if Dask handles these things efficiently in memory. If not, we can try sharedarray.

wigging commented 3 years ago

I'm getting close to a Dask solution. I ran the example using Python's ProcessPoolExecutor (from concurrent.futures import ProcessPoolExecutor) with chunked arrays, which gives good speed improvements. So now I'm adapting that approach for use with Dask. To be continued...
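A rough sketch of the ProcessPoolExecutor-with-chunks approach: the cells are split into one contiguous chunk per worker, and each step submits one task per chunk instead of one per cell. The chunking helper, the placeholder dynamics and the explicit `"fork"` context (which assumes a POSIX system) are assumptions for illustration, not the code from the dask branch.

```python
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor


def chunk_bounds(n_items, n_chunks):
    """Split range(n_items) into at most n_chunks contiguous (start, stop) pairs."""
    if n_items == 0:
        return []
    size = -(-n_items // n_chunks)  # ceiling division
    return [(s, min(s + size, n_items)) for s in range(0, n_items, size)]


def step_chunk(states, dt):
    """Advance one chunk of cell states; placeholder dynamics, not a cell model."""
    return [x - dt * x for x in states]


def parallel_step(states, dt, executor, n_chunks):
    # One task per chunk rather than per cell keeps the number of round trips small
    futures = [executor.submit(step_chunk, states[a:b], dt)
               for a, b in chunk_bounds(len(states), n_chunks)]
    new_states = []
    for fut in futures:  # collect in submission order so cell order is preserved
        new_states.extend(fut.result())
    return new_states


def run(n_cells=32, n_workers=4, n_steps=2, dt=0.5):
    ctx = mp.get_context("fork")  # assumes Linux or another POSIX system
    states = [1.0] * n_cells
    with ProcessPoolExecutor(max_workers=n_workers, mp_context=ctx) as executor:
        for _ in range(n_steps):
            states = parallel_step(states, dt, executor, n_workers)
    return states


if __name__ == "__main__":
    print(run()[:4])
```

Each chunk's states are still pickled to a worker and back on every step, which is the per-step transfer discussed earlier in the thread; chunking only cuts down how many of those transfers happen.
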

TomTranter commented 2 years ago

We now have a solution in the dask PR #66, but it's slower than just running the CasADi integrator. I don't know how that mapped integrator works, as the docs are very patchy and you can only get so far with a debugger, but they must be doing something clever because it kicks off almost instantaneously.

It looked like Dask actors were going to be useful, and they still might be, but they are quite experimental. The idea is to persist an actor class on a CPU process that holds its own state variables. This is basically an array of solutions for that actor's chunk of batteries. The actor then steps through them individually and updates the solutions, but only passes back the output. This is more efficient for data transfer, but it means keeping a lot of data in memory. A new actor is created for each process, and I guess there's some overhead with this. It doesn't seem to scale as well as CasADi (stepping the solve is about twice as slow, excluding the extra time to set up the client and instantiate the actors).

There is another package called Ray which does something similar, but its actor functionality is a bit more developed than Dask's, and they also talk about ways to use shared memory between processes, which may speed up sending inputs and outputs around. It's also hard to profile the Dask actors, as all their tasks are hidden from the task stream. I will check out Ray, but perhaps going forward we should think about writing our own parallel wrapper for SUNDIALS... Maybe @martinjrobins has some thoughts on how long this would take?
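The actor idea described above can be sketched with the standard library instead of Dask: each actor is a long-lived process that keeps its chunk of cell states in its own memory, receives only the time step, and sends back only a small output (here a sum standing in for the output variables). The `ChunkActor` class, the message protocol and the dynamics are hypothetical; this is not Dask's or Ray's actor API. The explicit `"fork"` start method assumes a POSIX system.

```python
import multiprocessing as mp


def _actor_loop(conn, initial_states):
    """Body of one actor process: it owns its chunk's states for the whole run."""
    states = list(initial_states)
    while True:
        dt = conn.recv()
        if dt is None:  # shutdown message
            conn.close()
            return
        states = [x - dt * x for x in states]  # placeholder dynamics, not a cell model
        conn.send(sum(states))  # reply with a small output, never the full state


class ChunkActor:
    """Hypothetical stateful worker that holds one chunk of cells in its own process."""

    def __init__(self, ctx, initial_states):
        self._conn, child_conn = ctx.Pipe()
        self._proc = ctx.Process(target=_actor_loop, args=(child_conn, initial_states))
        self._proc.start()
        child_conn.close()  # the parent only talks through its own end

    def step(self, dt):
        self._conn.send(dt)  # returns immediately for a small message like this

    def result(self):
        return self._conn.recv()

    def stop(self):
        self._conn.send(None)
        self._proc.join()


def run(n_cells=8, n_actors=2, n_steps=3, dt=0.5):
    ctx = mp.get_context("fork")  # assumes Linux or another POSIX system
    size = n_cells // n_actors    # assumes n_cells divides evenly between actors
    actors = [ChunkActor(ctx, [1.0] * size) for _ in range(n_actors)]
    outputs = []
    try:
        for _ in range(n_steps):
            for actor in actors:  # start every actor's step before waiting on any
                actor.step(dt)
            outputs = [actor.result() for actor in actors]
    finally:
        for actor in actors:
            actor.stop()
    return outputs


if __name__ == "__main__":
    print(run())
```

The trade-off from the comment shows up directly: state never crosses the pipe, but each actor keeps its whole chunk in memory for the entire run.
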

wigging commented 2 years ago

Keep in mind that the reason to use Dask is to be able to scale to multi-node clusters (CPUs and GPUs). Can Casadi, Ray, etc. take advantage of such hardware or are they limited to a single machine?

TomTranter commented 2 years ago

Ray certainly can. We need to find out how the CasADi map works. No answer from them on the forum yet.

TomTranter commented 2 years ago

Addressed in PR #66.