patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Question: `filter_shard` with PartitionSpecs, or other ways to ensure batched output, as in `vmap`. #759

Open · johannahaffner opened this issue 2 months ago

johannahaffner commented 2 months ago

Hi again! I have yet another question, this time to do with sharding.

Essentially, I'm trying to use sharding to improve performance, but I would like the output to be the same as that of `vmap`.

Background: I currently `vmap` to parallelise across my data, and I want to get a single value for each input array. (I do parameter estimation for ODEs across individuals.) I noticed that this approach only ever uses about 20% of my CPU, and confirmed that it essentially runs on a single CPU core. I'd like to change that to run on multiple CPU cores, and then scale to GPUs.
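(As an aside, one thing I ran into while setting this up: as far as I understand, `XLA_FLAGS` has to be set before JAX initialises its CPU backend, i.e. in practice before importing `jax`, otherwise `jax.devices()` still reports a single device. A minimal check, with the device count of 8 being my own choice:)

```python
# Expose several CPU devices to JAX. This must happen before jax is
# imported/initialised, otherwise the flag has no effect.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
print(jax.devices())  # Expect 8 CPU devices to be listed
```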

What I did so far:

Here is some code that reproduces the behavior I am talking about. Apologies for the length; I tried to condense it as much as I could.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # Use 8 CPU devices; must be set before importing jax

import pytest

import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
import jax.sharding as jshard

import equinox as eqx
import diffrax as dfx
import optimistix as optx

class ToyModel(eqx.Module):
    """Toy model that integrates an ODE."""
    _term: dfx.ODETerm
    def __init__(self, ode_model):
        self._term = dfx.ODETerm(ode_model)

    def __call__(self, param):
        sol = dfx.diffeqsolve(self._term, dfx.Tsit5(), 0, 30, 0.01, jnp.array([10.]), args=param, 
            saveat=dfx.SaveAt(ts=jnp.linspace(0, 30, 50)), adjoint=dfx.DirectAdjoint(), max_steps=16**4)
        return jnp.transpose(sol.ys)  # To get the shape I'm used to

class Estimator(eqx.Module):
    _solver: optx.AbstractLeastSquaresSolver
    def __init__(self, solver):
        self._solver = solver

    @eqx.filter_jit(donate="none")  # I hope that keeps it simpler for now
    def __call__(self, param, model, data):
        args = (model, data)
        def residuals(param, args):
            model, data = args
            fit = model(param)
            return data - fit
        solution = optx.least_squares(residuals, self._solver, param, args=args)
        return solution.value

# Create the model
def dydt(t, y, k): 
    return -k * y 
model = ToyModel(dydt)

# Generate fake data
rates = jnp.arange(0.1, 0.9, 0.1)  # Makes array with eight entries
ys = eqx.filter_vmap(model)(rates)

# Create the estimator
estimator = Estimator(optx.LevenbergMarquardt(atol=1e-06, rtol=1e-03))

# Solution: current approach (uses one core of my CPU even for very large data sets)
k0 = jnp.array([0.5]*8)
fitted_ks = eqx.filter_jit(eqx.filter_vmap(estimator, in_axes=(0, None, 0)))(k0, model, ys)  # Current approach

# Try something akin to the equinox data parallelism example
num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jshard.PositionalSharding(devices)
replicated = sharding.replicate()

sharded_estimator = eqx.filter_shard(estimator, replicated)
sharded_inputs = eqx.filter_shard((k0, model, ys), replicated)

with pytest.raises(ValueError):
    # This causes problems in diffrax: complaint about PyTree structure
    fitted_ks_with_sharding = sharded_estimator(*sharded_inputs)  

scalar_k0 = 0.5
sharded_inputs_with_scalar_k0 = eqx.filter_shard((scalar_k0, model, ys), replicated)
sharding_solutions = sharded_estimator(*sharded_inputs_with_scalar_k0)  # Seems to fit the average

print(sharding_solutions)
```
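For reference, this is roughly what I had in mind with PartitionSpecs: shard the batched arrays along their leading axis, replicate the model, and keep the `vmap`. The mesh and the axis name "batch" are my own choices, and I haven't verified that the whole pipeline really partitions across devices, so this is only a sketch:

```python
# Sketch only: a NamedSharding/PartitionSpec variant of the above.
mesh = jshard.Mesh(jax.devices(), axis_names=("batch",))
batch_sharding = jshard.NamedSharding(mesh, jshard.PartitionSpec("batch"))
replicated_sharding = jshard.NamedSharding(mesh, jshard.PartitionSpec())

# Shard the batched arrays along their leading (batch) axis, replicate the model.
k0_sharded, ys_sharded = eqx.filter_shard((k0, ys), batch_sharding)
model_replicated = eqx.filter_shard(model, replicated_sharding)

# Keep the vmap, and let jit partition the batched computation across devices.
vmapped_estimator = eqx.filter_vmap(estimator, in_axes=(0, None, 0))
fitted_ks_sharded = eqx.filter_jit(vmapped_estimator)(k0_sharded, model_replicated, ys_sharded)
```

Is that the intended way to combine `filter_shard` with `vmap`, or am I thinking about this the wrong way?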
johannahaffner commented 2 months ago

PS: a bit of a random thought, which I will try later: should I shard the residuals function instead?

patrick-kidger commented 2 months ago

Ah, I think this is something JAX still needs better docs for.

In response to the various points you raise:

johannahaffner commented 2 months ago

Hi Patrick,

Thank you for your thoughtful response! I always learn so much from your explanations.

I'm trying to whittle it down to a smaller example. So far I've noticed that no error is raised if I shard only the vmapped estimator and not the inputs (roughly as in the sketch below); the output is then what I would expect. My toy model is too small to check whether it is really running on all cores in that case, though.
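For concreteness, the variant I mean, reusing the names from the example above (only a sketch; I have checked the output, not the actual device utilisation):

```python
# Shard (replicate) only the vmapped estimator; leave k0, model and ys untouched.
vmapped_estimator = eqx.filter_vmap(estimator, in_axes=(0, None, 0))
sharded_vmapped_estimator = eqx.filter_shard(vmapped_estimator, replicated)
fitted_ks_sharded = eqx.filter_jit(sharded_vmapped_estimator)(k0, model, ys)
```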

I'm trying to get it to work on the non-toy example, but that adds quite a bit of complexity and I'm not there yet :)