patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0
265 stars 12 forks source link

Parallel multi start #67

Open MaAl13 opened 1 week ago

MaAl13 commented 1 week ago

Hello,

first of all i wanted to say that this is a really nice library. Then i wanted to ask if it is possible to do with filter_vmap something like a parallel multi start optimization. Meaning that if i want to run the minimization from different starting parameter guesses, can this be parallelized?

MaAl13 commented 1 week ago

I have implemented sth like the code below. But it doesn't seem to work:

def optimize_batch(initial_guesses, compute_mse):
    def single_optimization(initial_guess):
        solver = optx.BFGS(rtol=1e-5, atol=1e-5)
        try:
            solver = optx.BFGS(rtol=1e-5, atol=1e-5)
            sol = optx.minimise(compute_mse, solver = solver, y0 = initial_guess, max_steps = 5000)
            loss = compute_mse(sol.value, None)
            success = True
            #print(loss)
        except:
            loss = np.inf
            success = False
        return loss, success

    vectorized_optimization = eqx.filter_vmap(single_optimization)
    return vectorized_optimization(initial_guesses)

def main():

    # Generate initial guesses
    num_samples = 1024
    sobol = qmc.Sobol(d=len(simulation.flexible_params_indices), scramble=True)
    initial_guesses = jnp.array(sobol.random(n=num_samples))

    # Run the batch optimization
    losses, successes = optimize_batch(initial_guesses, compute_mse)

    losses = jnp.where(successes, losses, jnp.inf).astype('float64')

    # Find the best result
    best_index = jnp.argmin(losses)
    best_loss = losses[best_index]
    best_params = initial_guesses[best_index]

    print(f"Best loss: {best_loss} at index {best_index}")

if __name__ == "__main__":
    main()
patrick-kidger commented 1 week ago

Yup, it should be possible to vmap over the initial condition.

Unfortunately your example isn't runnable (I don't know what qmc is).

MaAl13 commented 6 days ago

Hello, thanks for the response. It is scipy.stats.qmc :) Here should be a runnable example code:

import optimistix as optx
import numpy as np
import equinox as eqx
from jax import config
from scipy.stats import qmc
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import Array, Float
import diffrax as dfx
from contextlib import contextmanager
import warnings
from tqdm import tqdm
import matplotlib.pyplot as plt
# JAX configuration
config.update("jax_enable_x64", True)

@contextmanager
def suppress_warnings():
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        yield

def vector_field(
    t, y: Float[Array, "2"], parameters: Float[Array, "4"]
) -> Float[Array, "2"]:
    prey, predator = y
    α, β, γ, δ = parameters
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    d_y = jnp.stack([d_prey, d_predator])
    return d_y

def solve(
    parameters: Float[Array, "4"], y0: Float[Array, "2"], saveat: dfx.SaveAt
) -> Float[Array, "ts"]:
    """Solve a single ODE."""
    term = dfx.ODETerm(vector_field)
    solver = dfx.Tsit5()
    t0 = saveat.subs.ts[0]
    t1 = saveat.subs.ts[-1]
    dt0 = 0.1
    sol = dfx.diffeqsolve(
        term,
        solver,
        t0,
        t1,
        dt0,
        y0,
        args=parameters,
        saveat=saveat,
        adjoint=dfx.DirectAdjoint(),
    )
    return sol.ys

# Generate noisy measurement data
def generate_noisy_data(true_params, y0, ts, noise_level=0.1):
    true_solution = solve(true_params, y0, dfx.SaveAt(ts=ts))
    key = jr.PRNGKey(0)
    noise = jr.normal(key, true_solution.shape) * noise_level
    return true_solution + noise

# Compute MSE loss
def compute_mse(params, data):
    y0, ts, noisy_data = data
    predicted = solve(params, y0, dfx.SaveAt(ts=ts))
    return jnp.mean((predicted - noisy_data)**2)

# def single_optimization(initial_guess, data):
#     solver = optx.BFGS(rtol=1e-5, atol=1e-5)
#     try:
#         sol = optx.minimise(compute_mse, solver=solver, y0=initial_guess, args = data, max_steps=5000)
#         loss = compute_mse(sol.value, data)
#         best_params = sol.value
#         success = True
#     except:
#         loss = jnp.inf
#         success = False
#         best_params = initial_guess
#     return best_params, loss, success

# def main():
#     # True parameters and initial condition
#     true_params = jnp.array([0.5, 0.025, 0.5, 0.005])
#     y0 = jnp.array([10.0, 5.0])
#     ts = jnp.linspace(0, 30, 100)

#     # Generate noisy data
#     noisy_data = generate_noisy_data(true_params, y0, ts)

#     # Package data for optimization
#     data = (y0, ts, noisy_data)

#     # Generate initial guesses
#     num_samples = 100
#     sobol = qmc.Sobol(d=4, scramble=True)
#     initial_guesses = jnp.array(sobol.random(n=num_samples))

#     losses = np.zeros(num_samples)
#     successes = []
#     solutions = np.zeros_like(initial_guesses)

#     # Run the batch optimization
#     for i, initial_guess in tqdm(enumerate(initial_guesses), total=num_samples, desc="Optimizing", ncols=100):
#         with suppress_warnings():
#             try:
#                 best_params, loss, success = single_optimization(initial_guess, data)
#                 successes.append(success)
#                 solutions[i, :] = best_params if success else np.inf
#             except Exception as e:
#                 loss = np.inf
#                 successes.append(False)
#                 solutions[i, :] = initial_guess
#         losses[i] = loss

#     print(f"Lowest loss: {np.nanmin(losses)}")
#     print(f"Lowest index: {np.nanargmin(losses)}")

#     best_parameters = solutions[np.nanargmin(losses),:]

#     # Generate the best-fit trajectory
#     best_trajectory = solve(best_parameters, y0, dfx.SaveAt(ts=ts))

#     # Plotting
#     fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))

#     # Plot prey
#     ax1.scatter(ts, noisy_data[:, 0], color='red', alpha=0.5, label='Noisy Data (Prey)')
#     ax1.plot(ts, best_trajectory[:, 0], color='blue', label='Best Fit (Prey)')
#     ax1.set_ylabel('Prey Population')
#     ax1.legend()
#     ax1.set_title('Prey Population: Noisy Data vs Best Fit')

#     # Plot predator
#     ax2.scatter(ts, noisy_data[:, 1], color='green', alpha=0.5, label='Noisy Data (Predator)')
#     ax2.plot(ts, best_trajectory[:, 1], color='orange', label='Best Fit (Predator)')
#     ax2.set_xlabel('Time')
#     ax2.set_ylabel('Predator Population')
#     ax2.legend()
#     ax2.set_title('Predator Population: Noisy Data vs Best Fit')

#     plt.tight_layout()
#     plt.show()

#     print(f"True parameters: {true_params}")
#     print(f"Best parameters: {best_parameters}")

# if __name__ == "__main__":
#     main()

def optimize_batch(initial_guesses, compute_mse, data):
    def single_optimization(initial_guess):
        solver = optx.BFGS(rtol=1e-5, atol=1e-5)
        try:
            sol = optx.minimise(compute_mse, solver=solver, y0=initial_guess, args = data, max_steps=5000)
            loss = compute_mse(sol.value, data)
            success = True
            best_params = sol.value  
        except:
            loss = jnp.inf
            success = False
            best_params = initial_guess 
        return best_params, loss, success

    vectorized_optimization = eqx.filter_vmap(single_optimization)
    return vectorized_optimization(initial_guesses)

def main():
    # True parameters and initial condition
    true_params = jnp.array([0.5, 0.025, 0.5, 0.005])
    y0 = jnp.array([10.0, 5.0])
    ts = jnp.linspace(0, 30, 100)

    # Generate noisy data
    noisy_data = generate_noisy_data(true_params, y0, ts)

    # Package data for optimization
    data = (y0, ts, noisy_data)

    # Generate initial guesses
    num_samples = 100
    sobol = qmc.Sobol(d=4, scramble=True)
    initial_guesses = jnp.array(sobol.random(n=num_samples))

    # Run the batch optimization
    best_params, losses, successes = optimize_batch(initial_guesses, compute_mse, data)

    # Find the best result
    best_index = jnp.argmin(losses)
    best_loss = losses[best_index]
    best_param = best_params[best_index]

    print(f"True parameters: {true_params}")
    print(f"Best parameters: {best_param}")
    print(f"Best loss: {best_loss} at index {best_index}")

if __name__ == "__main__":
    main()

In the commented out part i added the case with the for loop which takes 12 seconds on my mac. The vmap never finishes

patrick-kidger commented 6 days ago

This is quite a large MWE! I'd love to help, but it'd be great if you can condense this down to the most minimal thing that demonstrates your issue first. (Are qmc, tqdm, matplotlib all necessary to reproduce your issue? What about diffrax.diffeqsolve, or will a smaller non-diffrax function suffice? Etc.)

MaAl13 commented 15 hours ago

Hello, i think i boiled it down to a smaller MWE, for me it makes a difference if i use diffrax or not:

import optimistix as optx
import equinox as eqx
from jax import config
from scipy.stats import qmc
import jax.numpy as jnp
import diffrax as dfx
# JAX configuration
config.update("jax_enable_x64", True)
import time

def lotka_volterra(t, y, parameters):
    prey, predator = y
    α, β, γ, δ = parameters
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    d_y = jnp.stack([d_prey, d_predator])
    return d_y

def solve(parameters):
    """Solve a single ODE."""
    term = dfx.ODETerm(lotka_volterra)
    solver = dfx.Tsit5()
    saveat = dfx.SaveAt(ts=jnp.linspace(0, 30, 100))
    sol = dfx.diffeqsolve(term, solver, 0, 30, 0.1, jnp.array([1,1]), args=parameters, saveat=saveat, adjoint=dfx.DirectAdjoint())
    return sol.ys

# Compute MSE loss
def compute_mse(params, ts):
    predicted = solve(params)
    return jnp.mean((predicted - jnp.zeros_like(predicted))**2)

def optimize_batch(initial_guesses, compute_mse, ts):
    def single_optimization(initial_guess):
        solver = optx.BFGS(rtol=1e-5, atol=1e-5)
        try:
            sol = optx.minimise(compute_mse, solver=solver, y0=initial_guess, args = ts, max_steps=5000)
            loss = compute_mse(sol.value, ts)
            best_params = sol.value  
        except:
            loss = jnp.inf
            success = False
            best_params = initial_guess 
        return best_params, loss

    vectorized_optimization = eqx.filter_vmap(single_optimization)
    return vectorized_optimization(initial_guesses)

def main():
    # True parameters and initial condition
    ts = jnp.linspace(0, 30, 100)

    # Generate initial guesses
    num_samples = 10000
    sobol = qmc.Sobol(d=4, scramble=True)
    initial_guesses = jnp.array(sobol.random(n=num_samples))

    # Run the batch optimization
    best_params, losses = optimize_batch(initial_guesses, compute_mse, ts)

if __name__ == "__main__":
    main()