ToshiyukiBandai opened this issue 11 months ago
A few thoughts. TL;DR: JAX/XLA took care of it.
First, although I'm no programming-language-design expert, I'm not sure equinox objects actually fall under the standard OOP paradigm. Frozen data classes are basically functional, since `self` is immutable.
Second, `tree_at` is a rather complex function to implement (if you look at the code, and I encourage you to do so), but these sorts of operations are often accelerated by XLA: `tree_at` normally returns a copy, but if you write `x = tree_at(..., x, ...)` (rebinding the same name), then it doesn't actually have to make a copy and can do relatively fast read/write array/pytree operations.
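For intuition, here is an analogy in plain Python (this is not Equinox's actual implementation, just a sketch of the semantics): an out-of-place update on a frozen dataclass behaves much like `tree_at` in that it builds a new container but reuses the untouched fields, so nothing large is copied.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class State:
    x: list  # stand-in for a large array
    y: list


old = State(x=[1, 2, 3], y=[4, 5, 6])

# Out-of-place "update": a new State whose x is replaced and whose y is reused.
new = dataclasses.replace(old, x=[7, 8, 9])

assert new.y is old.y      # the untouched field is shared, not copied
assert old.x == [1, 2, 3]  # the original container is unchanged
```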
Third, if you inspect what JAX is running, you can see that these functions (the equinox one and the non-equinox one) are identical to JAX:
This is not wholly surprising, since the state tuple is a pytree (tuples are just pytrees: https://jax.readthedocs.io/en/latest/pytrees.html), and so is the equinox module. To JAX, both are just pytrees (which is one of the big selling points of equinox).
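As a toy illustration of the pytree idea (hand-rolled, not JAX's actual machinery, which lives in `jax.tree_util`): a tuple and a frozen dataclass can both be flattened to the same list of leaves, and the leaves are all the tracer ever sees.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class State:
    x: float
    y: float


def toy_flatten(obj):
    # Hand-rolled stand-in for jax.tree_util.tree_leaves:
    # reduce a nested container to its leaf values.
    if isinstance(obj, tuple):
        return [leaf for item in obj for leaf in toy_flatten(item)]
    if dataclasses.is_dataclass(obj):
        return [leaf for f in dataclasses.fields(obj)
                for leaf in toy_flatten(getattr(obj, f.name))]
    return [obj]


# A tuple and a dataclass holding the same data flatten to the same leaves.
assert toy_flatten((1.0, 2.0)) == toy_flatten(State(1.0, 2.0)) == [1.0, 2.0]
```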
So any differences in execution speed are largely fluctuations, I imagine. Here is the code to inspect the jaxprs:
```python
import timeit

import jax
from jax import lax, make_jaxpr
import jax.numpy as jnp
import equinox as eqx
from jaxtyping import Float, Array, Bool


def update_y(x):
    return x**2


## OOP implementation
class State(eqx.Module):
    x: Array
    y: Array

    def __init__(self, N_s):
        self.x = jnp.linspace(0, 1, N_s)
        self.y = update_y(self.x)


class Model(eqx.Module):
    def update_state(self, x, dx):
        x = x + dx
        y = update_y(x)
        return x, y


def time_stepping(state, dx):
    new_x, new_y = model.update_state(state.x, dx)
    state = eqx.tree_at(lambda t: t.x, state, new_x)
    state = eqx.tree_at(lambda t: t.y, state, new_y)
    return state, new_y


state = State(1000)
model = Model()
dx_array = jnp.linspace(0.1, 1, 1000)
carry, output = lax.scan(time_stepping, state, dx_array)
time_class = lambda: jax.block_until_ready(lax.scan(time_stepping, state, dx_array))
print(make_jaxpr(time_stepping)(state, dx_array[0]))
print(min(timeit.repeat(time_class, number=1, repeat=10)))

## Functional Programming
jax.clear_caches()

N_s = 1000
x = jnp.linspace(0, 1, N_s)
y = update_y(x)
state_tuple = (x, y)


def update_state(x, dx):
    x = x + dx
    y = update_y(x)
    return x, y


def time_stepping(state, dx):
    x, y = state
    new_x, new_y = update_state(x, dx)
    return (new_x, new_y), new_y


dx_array = jnp.linspace(0.1, 1, 1000)
carry, output = lax.scan(time_stepping, state_tuple, dx_array)
print(make_jaxpr(time_stepping)(state_tuple, dx_array[0]))
time_tuple = lambda: jax.block_until_ready(lax.scan(time_stepping, state_tuple, dx_array))
print(min(timeit.repeat(time_tuple, number=1, repeat=10)))
```
To expand on this. Something like this:

```python
a, b, c = ...
x = [a, b]
new_x = eqx.tree_at(lambda y: y[0], x, c)
```
is entirely equivalent to this:

```python
a, b, c = ...
x = [a, b]  # unused
new_x = [c, b]
```
in which you'll note that the only thing we actually create on the last line is `new_x`. We never actually make a copy of `b` or `c`!
To reiterate, `eqx.tree_at` never actually makes a copy of an array. It just manipulates the surrounding containers (pytrees). You can check this yourself:
```python
# after the first snippet above
assert new_x[0] is c  # exact same object, no copies made
```
That's the first half of why this is fast -- there are no array copies.
Next, this kind of pytree manipulation actually vanishes once things are compiled with XLA. All of this pytree handling is just a way to build a computation graph of arrays; XLA doesn't even see the pytrees at all. You can see this in @lockwo's first image, in which the two computations have the same jaxpr. (Jaxprs are JAX's internal representation of what computation is performed.) This is the second half of why this is fast: all pytree manipulations are Python-level operations that vanish under compilation.
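To make "vanishes under compilation" concrete, here is a toy tracer in plain Python (an analogy for how JAX builds jaxprs, not JAX itself): packing and unpacking containers happens at the Python level and leaves no trace in the recorded operations; only arithmetic on traced values is recorded.

```python
class Tracer:
    """Records every arithmetic op performed on it, like a jaxpr trace."""

    def __init__(self, name, ops):
        self.name = name
        self.ops = ops

    def __add__(self, other):
        self.ops.append("add")
        return Tracer(f"({self.name}+...)", self.ops)

    def __mul__(self, other):
        self.ops.append("mul")
        return Tracer(f"({self.name}*...)", self.ops)


def trace(f, n_args):
    # Run f on tracers and return the list of recorded operations.
    ops = []
    args = [Tracer(f"x{i}", ops) for i in range(n_args)]
    f(*args)
    return ops


# Version with tuple (pytree-style) packing/unpacking:
def with_containers(a, b):
    state = (a, b)  # pack -- pure Python, leaves no trace
    x, y = state    # unpack -- pure Python, leaves no trace
    return (x + y) * y


# Version operating on values directly:
def plain(a, b):
    return (a + b) * b


# Both traces record exactly the same arithmetic: one add, one mul.
assert trace(with_containers, 2) == trace(plain, 2) == ["add", "mul"]
```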
Other comments:

1. > I'm not sure if equinox objects actually fall under the standard OOP paradigm. Frozen data classes are basically functional, since the self is immutable.

   @lockwo is completely correct here. "OOP" refers specifically to mutating classes, typically via their methods. In contrast, Equinox modules are an immutable collection of elements -- which "morally speaking" basically just makes them a pretty syntax for a `tuple`. :)
2. Here you are using `lax.scan`, and that internally secretly does a sort-of-JIT (that's how it handles the body function), but in general you'll find that your computation is slightly slower than if you place a JIT around your whole computation (including the scan).

Hi lockwo and Patrick,
Thank you both for your detailed explanations. It makes sense. As a user, I was worried a bit about the overhead of Equinox operations, so it might be good to emphasize somewhere how high performance can be achieved as you described. Thank you!
Hi Equinox community,

I am using equinox to change my program from a functional programming style to an object-oriented programming style. In the example below (updating state variables x and y at each time step), I expected the object-oriented version using equinox to be slower than the functional version because of eqx.tree_at. But it turned out the wall times were comparable. Is this an expected result? Does eqx.tree_at not incur any additional overhead? Or did XLA take care of everything?