Open zhengqigao opened 4 months ago
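I recommend checking out JAX's docs on benchmarking: https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code. The TL;DR for this example is to keep the first (compiling) call out of the timed region and to call `block_until_ready()` on the result: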
With the following code I got: 19.4 ms ± 4.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
import equinox as eqx
import jax
import diffrax
import jax.numpy as jnp
import time
class MLPeqx(eqx.Module):
    """Feed-forward MLP: a stack of ``eqx.nn.Linear`` layers with ReLU in between.

    ``hidden_dims`` lists the layer widths, input size first and output size
    last; one linear layer is created for each consecutive pair of widths.
    """

    # Linear layers, in the order they are applied.
    layers: list
    # Activation applied after every layer except the last; declared static.
    activation: callable = eqx.static_field()

    def __init__(self, hidden_dims):
        """Build the layers, seeding each one from a key split off ``PRNGKey(0)``."""
        super().__init__()
        num_layers = len(hidden_dims) - 1
        layer_keys = jax.random.split(jax.random.PRNGKey(0), num_layers)
        # Fix: the loop variable used to be spelled ``I`` while the body indexed
        # with ``i``, which raised NameError as soon as the model was built.
        self.layers = [
            eqx.nn.Linear(hidden_dims[i], hidden_dims[i + 1], key=layer_keys[i])
            for i in range(num_layers)
        ]
        self.activation = jax.nn.relu

    def __call__(self, x):
        """Run ``x`` through all layers; no activation follows the final layer."""
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        return self.layers[-1](x)
class ODEjax(eqx.Module):
    """Vector field for the ODE: evaluates an ``MLPeqx`` on the current state."""

    # Network evaluated on the state ``y``.
    func: MLPeqx

    def __init__(self, hidden_dims):
        """Create the underlying MLP from the given layer widths."""
        super().__init__()
        network = MLPeqx(hidden_dims)
        self.func = network

    def __call__(self, t, y, args=None):
        """Return the MLP output for ``y``; ``t`` and ``args`` are not used."""
        derivative = self.func(y)
        return derivative
def solve_ode(input_x, t, func, cfg):
    """Solve the ODE defined by ``func`` from ``t[0]`` to ``t[-1]``.

    Args:
        input_x: Initial state, passed to diffrax as ``y0``.
        t: Sequence of times; its first and last entries bound the
            integration and the solution is saved at every entry.
        func: Callable wrapped in ``diffrax.ODETerm`` as the vector field.
        cfg: Mapping with keys ``'method'`` (the solver), ``'atol'`` and
            ``'rtol'`` (tolerances for the PID step-size controller).

    Returns:
        The ``ys`` attribute of the diffrax solution.
    """
    solver = cfg['method']
    term = diffrax.ODETerm(func)
    controller = diffrax.PIDController(atol=cfg['atol'], rtol=cfg['rtol'])
    solution = diffrax.diffeqsolve(
        term,
        solver,
        t0=t[0],
        t1=t[-1],
        y0=input_x,
        dt0=None,  # no explicit initial step size is supplied
        saveat=diffrax.SaveAt(ts=t),
        stepsize_controller=controller,
    )
    return solution.ys
@eqx.filter_jit
def run_diffrax(hidden_dims, input_x, t, num_t, cfg):
    """Solve the MLP-driven ODE for every row of ``input_x``.

    Args:
        hidden_dims: Layer widths used to build the ``ODEjax`` vector field.
        input_x: Batch of initial states; the solve is mapped over axis 0.
        t: Pair ``[start, stop]`` giving the integration interval.
        num_t: Number of evenly spaced save times between start and stop.
        cfg: Solver configuration forwarded unchanged to ``solve_ode``.

    Returns:
        The stacked ``solve_ode`` results, one per batch element.
    """
    start, stop = t[0], t[1]
    save_times = jnp.linspace(start, stop, num_t)
    vector_field = ODEjax(hidden_dims)
    # Only the initial state is batched; times, model and config are shared.
    batched_solve = jax.vmap(solve_ode, in_axes=(0, None, None, None))
    return batched_solve(input_x, save_times, vector_field, cfg)
batch_size = 128
hidden_dims = [100, 100, 100]
input_x = jax.random.normal(jax.random.PRNGKey(0), (128, 100))
_ = run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
'method': diffrax.Dopri5(),
'atol': 1e-5,
'rtol': 1e-5}).block_until_ready()
%%timeit
_ = run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
'method': diffrax.Dopri5(),
'atol': 1e-5,
'rtol': 1e-5}).block_until_ready()
Thanks so much! I tried it on my end and observed similar run times. I have a follow-up question. Say I first want to run with `atol=rtol=1e-5`, and later in my code I want to run with `atol=rtol=1e-4`. I observe that `run_diffrax` becomes slow again when changing from `1e-5` to `1e-4`, presumably because of recompilation. Namely,
# first time of atol=rtol=1e-5, takes ~2secs
run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': 1e-5,
    'rtol': 1e-5}).block_until_ready()
# second time of atol=rtol=1e-5, takes ~0.008secs
run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': 1e-5,
    'rtol': 1e-5}).block_until_ready()
# first time of atol=rtol=1e-4, takes ~2secs
# (tolerances below now actually use 1e-4, matching the description)
run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': 1e-4,
    'rtol': 1e-4}).block_until_ready()
# second time of atol=rtol=1e-4, takes ~0.008secs
run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': 1e-4,
    'rtol': 1e-4}).block_until_ready()
Is this behavior expected? I wonder if there is a way to compile only once for arbitrary `atol` and `rtol` values, so that it always runs at the millisecond level regardless of `atol` and `rtol`.
Yes, this behavior is expected. The Python floats are marked as static by the filtering that happens before `jit`, so every new value triggers a recompilation. You can make them non-static by passing them as JAX types (e.g. arrays).
# Time four consecutive calls with the tolerances passed as JAX arrays:
# two at atol=1e-5 and two at atol=1e-4, all with rtol=1e-5.
tolerance_pairs = [
    (1e-5, 1e-5),
    (1e-5, 1e-5),
    (1e-4, 1e-5),
    (1e-4, 1e-5),
]
for atol, rtol in tolerance_pairs:
    tic = time.time()
    run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
        'method': diffrax.Dopri5(),
        'atol': jnp.array(atol),
        'rtol': jnp.array(rtol)}).block_until_ready()
    elapsed = time.time() - tic
    print(f"run time = {elapsed:.3f} (sec)")
run time = 4.057 (sec)
run time = 0.016 (sec)
run time = 0.013 (sec)
run time = 0.013 (sec)
Hi,
I am playing around with `diffrax`'s ODE-solving functionality. In a nutshell, I define a simple feedforward MLP with random initialization and benchmark the runtime of using it as the time derivative of an ODE. I wrote the following code to record the run time of the ODE solve and got a run time of around 3.7 sec, which seems much slower than other ODE solver frameworks. I am new to JAX and diffrax. What did I miss in my code implementation?