pybamm-team / liionpack

A battery pack simulation tool that uses the PyBaMM framework
https://liionpack.readthedocs.io/en/latest/
MIT License

using parallel jax solver #61

Open martinjrobins opened 2 years ago

martinjrobins commented 2 years ago

@TomTranter: here is an example script that runs 20 instances of the SPM model in parallel using JAX. Let me know if this is what you need.

import os

# specify 20 logical devices for execution; XLA_FLAGS must be set before
# JAX initialises its backend, so set it before importing jax (or pybamm,
# which may import jax itself)
ncpu = 20
os.environ['XLA_FLAGS'] = (
    '--xla_force_host_platform_device_count={}'.format(ncpu)
)

import time
import pybamm
import numpy as np
import jax

# print out the available devices
print('devices', jax.devices())

pybamm.set_logging_level("INFO")
model = pybamm.lithium_ion.SPM()
model.convert_to_format = "jax"
# the JAX solver does not support termination events, so remove them
model.events = []

# create geometry
geometry = model.default_geometry

# load parameter values and process model and geometry
param = model.default_parameter_values
parameter = "Electrode height [m]"
value = param[parameter]
param.update({parameter: "[input]"})
param.process_model(model)
param.process_geometry(geometry)

# set mesh
mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts)

# discretise model
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

# solve model for 1 hour
t_eval = np.linspace(0, 3600, 100)
solver = pybamm.JaxSolver()

# the model is set up and compiled (expensive) during this call
solution = solver.solve(
    model, t_eval,
    inputs={parameter: value},
)

# create a new jax function that can take an array of inputs
mapped_jax_function = jax.pmap(
    solver.get_solve(model, t_eval),
)

# now our input is an array of values for "parameter"; pmap maps over
# its leading axis, which must not be longer than the number of devices
# (here there is exactly one value per device)
inputs_array = {
    parameter: jax.numpy.linspace(value / 10, value * 10, ncpu)
}

# the multiple inputs are executed in parallel here;
# the result is a 3D array of shape (Ni, Ns, Nt), where
# Ni is the size of the parameter array, Ns is the size of the
# full state vector, and Nt is the number of time steps in t_eval.
# This first call also compiles the pmapped function, so the
# elapsed time below includes compilation
print('running in parallel')
tic = time.perf_counter()
result_array = mapped_jax_function(inputs_array)

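# JAX dispatches work asynchronously, so block until the whole result
# has been computed before stopping the timer
result_array.block_until_ready()
print(result_array[0, 0, 0])
toc = time.perf_counter()
print('time elapsed: {} sec'.format(toc - tic))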
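The pmap call above maps over the leading axis of the input array, and that axis cannot be longer than the number of available devices. To sweep more parameter values than there are devices, one option is to split the values into device-sized batches and call the mapped function once per batch. The sketch below is a minimal, stdlib-only illustration of that batching; `device_batches` and `sweep` are hypothetical helpers, and `run_batch` stands in for a call to the pmapped solve function. The last batch is padded by repeating its final value so every call has the same shape, which should avoid recompiling for a smaller final batch; the padded results are then dropped.

```python
from typing import Callable, List, Sequence


def device_batches(values: Sequence[float], ndev: int) -> List[List[float]]:
    """Hypothetical helper: split values into batches of exactly ndev items,
    padding the last batch by repeating its final value."""
    if ndev < 1:
        raise ValueError("ndev must be at least 1")
    batches = []
    for start in range(0, len(values), ndev):
        batch = list(values[start:start + ndev])
        # pad so every batch has the same length (same shape for pmap)
        batch += [batch[-1]] * (ndev - len(batch))
        batches.append(batch)
    return batches


def sweep(run_batch: Callable[[List[float]], Sequence[object]],
          values: Sequence[float], ndev: int) -> List[object]:
    """Hypothetical helper: call run_batch once per device-sized batch and
    return one result per input value, in input order, without padding."""
    results: List[object] = []
    for batch in device_batches(values, ndev):
        # with JAX, this would be one parallel call, one value per device
        results.extend(run_batch(batch))
    return results[:len(values)]


# toy stand-in for a model solve: square each input value
print(sweep(lambda b: [x * x for x in b], [1, 2, 3, 4, 5], ndev=2))
```

With the script above, `run_batch` could plausibly be `lambda b: list(mapped_jax_function({parameter: jax.numpy.asarray(b)}))` with `ndev=ncpu`, so that each element of the returned list is one (Ns, Nt) solution; this wiring is an assumption and has not been tested against PyBaMM.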
TomTranter commented 2 years ago

Fantastic, thanks. Will you be joining the dev meeting later? I'll play with it now and may have a few questions.

martinjrobins commented 2 years ago

I've got a shortlisting meeting that should end at 2pm. I can join after that, or give you a call once it's finished.

wigging commented 2 years ago

@martinjrobins Can JAX run in parallel across multiple machines?

wigging commented 2 years ago

I tried to run the JAX example that @martinjrobins posted above, but I get the following error:

ValueError: model.timescale must be a Scalar after parameter processing
(cannot contain 'InputParameter's). You have probably set one of the
parameters used to calculate the timescale to an InputParameter. To avoid
this error, hardcode model.timescale to a constant value by passing the
option {'timescale': value} to the model.
TomTranter commented 2 years ago

@wigging Yes, we recently changed how the timescale is implemented in PyBaMM. Can you run it with the suggested option for a scalar timescale?

wigging commented 2 years ago

Using model = pybamm.lithium_ion.SPM(options={"timescale": 1.0}) makes the example work.