Open MaAl13 opened 1 week ago
I have implemented something like the code below, but it doesn't seem to work:
def optimize_batch(initial_guesses, compute_mse):
    """Run one BFGS minimisation per row of ``initial_guesses`` via ``filter_vmap``.

    Args:
        initial_guesses: Array whose leading axis indexes the starting points.
        compute_mse: Objective called as ``compute_mse(params, args)``; it is
            minimised with ``args=None``.

    Returns:
        A ``(losses, successes)`` pair batched over the leading axis.
        ``losses`` is ``inf`` wherever the corresponding solve did not report
        success; ``successes`` is a boolean array.
    """
    # Local import: no import block is visible for this snippet.
    import jax.numpy as jnp

    def single_optimization(initial_guess):
        solver = optx.BFGS(rtol=1e-5, atol=1e-5)
        # This function is called once for the whole batch under vmap, so a
        # Python try/except cannot flag individual failed starts.  Ask
        # Optimistix for a result code instead of an exception.
        sol = optx.minimise(
            compute_mse,
            solver=solver,
            y0=initial_guess,
            max_steps=5000,
            throw=False,
        )
        success = sol.result == optx.RESULTS.successful
        # Keep the original contract: failed runs report an infinite loss.
        loss = jnp.where(success, compute_mse(sol.value, None), jnp.inf)
        return loss, success

    vectorized_optimization = eqx.filter_vmap(single_optimization)
    return vectorized_optimization(initial_guesses)
def main():
    """Draw Sobol starting points, optimise them as a batch, print the best loss."""
    # Generate initial guesses
    n_starts = 1024
    n_dims = len(simulation.flexible_params_indices)
    sampler = qmc.Sobol(d=n_dims, scramble=True)
    initial_guesses = jnp.array(sampler.random(n=n_starts))

    # Run the batch optimization
    raw_losses, successes = optimize_batch(initial_guesses, compute_mse)
    # Runs flagged as unsuccessful must never win the argmin below.
    losses = jnp.where(successes, raw_losses, jnp.inf).astype('float64')

    # Find the best result
    best_index = jnp.argmin(losses)
    best_loss = losses[best_index]
    # NOTE(review): this is the *starting point* of the best run, not the
    # optimised parameters; it is not used afterwards.
    best_start = initial_guesses[best_index]
    print(f"Best loss: {best_loss} at index {best_index}")


if __name__ == "__main__":
    main()
Yup, it should be possible to vmap over the initial condition.
Unfortunately your example isn't runnable (I don't know what `qmc` is).
Hello, thanks for the response. It is `scipy.stats.qmc` :) Here is a runnable example:
import optimistix as optx
import numpy as np
import equinox as eqx
from jax import config
from scipy.stats import qmc
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import Array, Float
import diffrax as dfx
from contextlib import contextmanager
import warnings
from tqdm import tqdm
import matplotlib.pyplot as plt
# JAX configuration
config.update("jax_enable_x64", True)
@contextmanager
def suppress_warnings():
    """Context manager that ignores every warning raised inside its body.

    The warning filters are saved on entry and restored on exit by
    ``warnings.catch_warnings``.
    """
    saved_filters = warnings.catch_warnings()
    with saved_filters:
        warnings.simplefilter("ignore")
        yield
def vector_field(
    t, y: Float[Array, "2"], parameters: Float[Array, "4"]
) -> Float[Array, "2"]:
    """Right-hand side of the two-species predator-prey ODE.

    Args:
        t: Time; accepted for the ODE-term signature but not used.
        y: Current state, unpacked as ``(prey, predator)``.
        parameters: Four coefficients, unpacked in order as the prey growth,
            prey-loss interaction, predator decay and predator-gain
            interaction terms.

    Returns:
        The stacked time derivatives ``[d_prey, d_predator]``.
    """
    prey, predator = y
    alpha, beta, gamma, delta = parameters
    return jnp.stack(
        [
            alpha * prey - beta * prey * predator,
            -gamma * predator + delta * prey * predator,
        ]
    )
def solve(
    parameters: Float[Array, "4"], y0: Float[Array, "2"], saveat: dfx.SaveAt
) -> Float[Array, "ts"]:
    """Integrate ``vector_field`` with Tsit5 and return the saved states.

    Args:
        parameters: Passed to the vector field as its ``args``.
        y0: Initial state.
        saveat: Save specification; the first and last entries of
            ``saveat.subs.ts`` are used as the integration start and end times.

    Returns:
        ``sol.ys`` from ``diffrax.diffeqsolve``.
    """
    save_times = saveat.subs.ts
    solution = dfx.diffeqsolve(
        dfx.ODETerm(vector_field),
        dfx.Tsit5(),
        t0=save_times[0],
        t1=save_times[-1],
        dt0=0.1,
        y0=y0,
        args=parameters,
        saveat=saveat,
        adjoint=dfx.DirectAdjoint(),
    )
    return solution.ys
# Generate noisy measurement data
def generate_noisy_data(true_params, y0, ts, noise_level=0.1):
    """Solve the ODE for ``true_params`` and add Gaussian noise to the result.

    The noise is drawn with the fixed key ``PRNGKey(0)``, so repeated calls
    with the same arguments return the same data.

    Args:
        true_params: Parameters handed to ``solve``.
        y0: Initial state handed to ``solve``.
        ts: Times at which the solution is saved.
        noise_level: Scale multiplied onto the standard-normal samples.

    Returns:
        The solution plus noise, with the same shape as the solution.
    """
    clean = solve(true_params, y0, dfx.SaveAt(ts=ts))
    perturbation = noise_level * jr.normal(jr.PRNGKey(0), clean.shape)
    return clean + perturbation
# Compute MSE loss
def compute_mse(params, data):
    """Mean squared error between the simulated trajectory and the data.

    Args:
        params: Parameters handed to ``solve``.
        data: Tuple ``(y0, ts, noisy_data)``.

    Returns:
        The mean of the squared differences between ``solve``'s output at
        ``ts`` and ``noisy_data``.
    """
    y0, ts, observed = data
    residual = solve(params, y0, dfx.SaveAt(ts=ts)) - observed
    return jnp.mean(residual ** 2)
# def single_optimization(initial_guess, data):
# solver = optx.BFGS(rtol=1e-5, atol=1e-5)
# try:
# sol = optx.minimise(compute_mse, solver=solver, y0=initial_guess, args = data, max_steps=5000)
# loss = compute_mse(sol.value, data)
# best_params = sol.value
# success = True
# except:
# loss = jnp.inf
# success = False
# best_params = initial_guess
# return best_params, loss, success
# def main():
# # True parameters and initial condition
# true_params = jnp.array([0.5, 0.025, 0.5, 0.005])
# y0 = jnp.array([10.0, 5.0])
# ts = jnp.linspace(0, 30, 100)
# # Generate noisy data
# noisy_data = generate_noisy_data(true_params, y0, ts)
# # Package data for optimization
# data = (y0, ts, noisy_data)
# # Generate initial guesses
# num_samples = 100
# sobol = qmc.Sobol(d=4, scramble=True)
# initial_guesses = jnp.array(sobol.random(n=num_samples))
# losses = np.zeros(num_samples)
# successes = []
# solutions = np.zeros_like(initial_guesses)
# # Run the batch optimization
# for i, initial_guess in tqdm(enumerate(initial_guesses), total=num_samples, desc="Optimizing", ncols=100):
# with suppress_warnings():
# try:
# best_params, loss, success = single_optimization(initial_guess, data)
# successes.append(success)
# solutions[i, :] = best_params if success else np.inf
# except Exception as e:
# loss = np.inf
# successes.append(False)
# solutions[i, :] = initial_guess
# losses[i] = loss
# print(f"Lowest loss: {np.nanmin(losses)}")
# print(f"Lowest index: {np.nanargmin(losses)}")
# best_parameters = solutions[np.nanargmin(losses),:]
# # Generate the best-fit trajectory
# best_trajectory = solve(best_parameters, y0, dfx.SaveAt(ts=ts))
# # Plotting
# fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
# # Plot prey
# ax1.scatter(ts, noisy_data[:, 0], color='red', alpha=0.5, label='Noisy Data (Prey)')
# ax1.plot(ts, best_trajectory[:, 0], color='blue', label='Best Fit (Prey)')
# ax1.set_ylabel('Prey Population')
# ax1.legend()
# ax1.set_title('Prey Population: Noisy Data vs Best Fit')
# # Plot predator
# ax2.scatter(ts, noisy_data[:, 1], color='green', alpha=0.5, label='Noisy Data (Predator)')
# ax2.plot(ts, best_trajectory[:, 1], color='orange', label='Best Fit (Predator)')
# ax2.set_xlabel('Time')
# ax2.set_ylabel('Predator Population')
# ax2.legend()
# ax2.set_title('Predator Population: Noisy Data vs Best Fit')
# plt.tight_layout()
# plt.show()
# print(f"True parameters: {true_params}")
# print(f"Best parameters: {best_parameters}")
# if __name__ == "__main__":
# main()
def optimize_batch(initial_guesses, compute_mse, data):
    """Run one BFGS minimisation per row of ``initial_guesses`` via ``filter_vmap``.

    Args:
        initial_guesses: Array whose leading axis indexes the starting points.
        compute_mse: Objective called as ``compute_mse(params, data)``.
        data: Passed unchanged (not batched) to every minimisation as ``args``.

    Returns:
        A ``(best_params, losses, successes)`` triple batched over the leading
        axis.  For runs that did not report success, ``best_params`` falls back
        to the starting point and ``losses`` is ``inf``; ``successes`` is a
        boolean array.
    """

    def single_optimization(initial_guess):
        solver = optx.BFGS(rtol=1e-5, atol=1e-5)
        # This function is called once for the whole batch under vmap, so a
        # Python try/except cannot flag individual failed starts.  Ask
        # Optimistix for a result code instead of an exception.
        sol = optx.minimise(
            compute_mse,
            solver=solver,
            y0=initial_guess,
            args=data,
            max_steps=5000,
            throw=False,
        )
        success = sol.result == optx.RESULTS.successful
        # Same fallbacks as before, but selected per sample.
        loss = jnp.where(success, compute_mse(sol.value, data), jnp.inf)
        best_params = jnp.where(success, sol.value, initial_guess)
        return best_params, loss, success

    vectorized_optimization = eqx.filter_vmap(single_optimization)
    return vectorized_optimization(initial_guesses)
def main():
    """Fit the ODE parameters to synthetic noisy data from many starting points."""
    # True parameters and initial condition
    true_params = jnp.array([0.5, 0.025, 0.5, 0.005])
    y0 = jnp.array([10.0, 5.0])
    ts = jnp.linspace(0, 30, 100)

    # Generate noisy data and package it for optimization
    data = (y0, ts, generate_noisy_data(true_params, y0, ts))

    # Generate initial guesses
    n_starts = 100
    sampler = qmc.Sobol(d=4, scramble=True)
    initial_guesses = jnp.array(sampler.random(n=n_starts))

    # Run the batch optimization
    results = optimize_batch(initial_guesses, compute_mse, data)
    best_params, losses, successes = results

    # Find the best result
    best_index = jnp.argmin(losses)
    best_loss = losses[best_index]
    best_param = best_params[best_index]

    print(f"True parameters: {true_params}")
    print(f"Best parameters: {best_param}")
    print(f"Best loss: {best_loss} at index {best_index}")


if __name__ == "__main__":
    main()
In the commented-out part I added the version with the for loop, which takes 12 seconds on my Mac. The vmap version never finishes.
This is quite a large MWE! I'd love to help, but it'd be great if you can condense this down to the most minimal thing that demonstrates your issue first. (Are `qmc`, `tqdm`, `matplotlib` all necessary to reproduce your issue? What about `diffrax.diffeqsolve`, or will a smaller non-diffrax function suffice? Etc.)
Hello, I think I have boiled it down to a smaller MWE. For me it makes a difference whether I use diffrax or not:
import optimistix as optx
import equinox as eqx
from jax import config
from scipy.stats import qmc
import jax.numpy as jnp
import diffrax as dfx
# JAX configuration
config.update("jax_enable_x64", True)
import time
def lotka_volterra(t, y, parameters):
    """Right-hand side of the two-species predator-prey ODE.

    Args:
        t: Time; accepted for the ODE-term signature but not used.
        y: Current state, unpacked as ``(prey, predator)``.
        parameters: Four coefficients, unpacked in order as the prey growth,
            prey-loss interaction, predator decay and predator-gain
            interaction terms.

    Returns:
        The stacked time derivatives ``[d_prey, d_predator]``.
    """
    prey, predator = y
    alpha, beta, gamma, delta = parameters
    return jnp.stack(
        [
            alpha * prey - beta * prey * predator,
            -gamma * predator + delta * prey * predator,
        ]
    )
def solve(parameters):
    """Integrate ``lotka_volterra`` from t=0 to t=30 and return the saved states.

    The initial state is ``[1, 1]`` and the solution is saved at 100 evenly
    spaced times between 0 and 30.

    Args:
        parameters: Passed to the vector field as its ``args``.

    Returns:
        ``sol.ys`` from ``diffrax.diffeqsolve``.
    """
    save_times = jnp.linspace(0, 30, 100)
    solution = dfx.diffeqsolve(
        dfx.ODETerm(lotka_volterra),
        dfx.Tsit5(),
        t0=0,
        t1=30,
        dt0=0.1,
        y0=jnp.array([1, 1]),
        args=parameters,
        saveat=dfx.SaveAt(ts=save_times),
        adjoint=dfx.DirectAdjoint(),
    )
    return solution.ys
# Compute MSE loss
def compute_mse(params, ts):
    """Mean squared deviation of the simulated trajectory from zero.

    Args:
        params: Parameters handed to ``solve``.
        ts: Not used by this function; present so the objective has the
            ``(y, args)`` shape it is called with.

    Returns:
        The mean of the squared difference between the trajectory and an
        all-zeros target of the same shape.
    """
    trajectory = solve(params)
    target = jnp.zeros_like(trajectory)
    residual = trajectory - target
    return jnp.mean(residual ** 2)
def optimize_batch(initial_guesses, compute_mse, ts):
    """Run one BFGS minimisation per row of ``initial_guesses`` via ``filter_vmap``.

    Args:
        initial_guesses: Array whose leading axis indexes the starting points.
        compute_mse: Objective called as ``compute_mse(params, ts)``.
        ts: Passed unchanged (not batched) to every minimisation as ``args``.

    Returns:
        A ``(best_params, losses)`` pair batched over the leading axis.  For
        runs that did not report success, ``best_params`` falls back to the
        starting point and ``losses`` is ``inf``.
    """

    def single_optimization(initial_guess):
        solver = optx.BFGS(rtol=1e-5, atol=1e-5)
        # This function is called once for the whole batch under vmap, so a
        # Python try/except cannot flag individual failed starts.  Ask
        # Optimistix for a result code instead of an exception.
        sol = optx.minimise(
            compute_mse,
            solver=solver,
            y0=initial_guess,
            args=ts,
            max_steps=5000,
            throw=False,
        )
        success = sol.result == optx.RESULTS.successful
        # Same fallbacks as before, but selected per sample.
        loss = jnp.where(success, compute_mse(sol.value, ts), jnp.inf)
        best_params = jnp.where(success, sol.value, initial_guess)
        return best_params, loss

    vectorized_optimization = eqx.filter_vmap(single_optimization)
    return vectorized_optimization(initial_guesses)
def main():
    """Run the batched multi-start optimisation on 10000 Sobol starting points."""
    # Value forwarded to the objective as its ``args``.
    ts = jnp.linspace(0, 30, 100)

    # Generate initial guesses
    n_starts = 10000
    sampler = qmc.Sobol(d=4, scramble=True)
    initial_guesses = jnp.array(sampler.random(n=n_starts))

    # Run the batch optimization (results are not inspected further here)
    outcome = optimize_batch(initial_guesses, compute_mse, ts)
    best_params, losses = outcome


if __name__ == "__main__":
    main()
Hello,
First of all, I wanted to say that this is a really nice library. I also wanted to ask whether it is possible to use `filter_vmap` to do something like a parallel multi-start optimization. That is, if I want to run the minimization from several different starting parameter guesses, can this be parallelized?