patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Computational performance of eqx.tree_at #572

Open ToshiyukiBandai opened 11 months ago

ToshiyukiBandai commented 11 months ago

Hi Equinox community,

I am using Equinox to move my program from a functional programming style to an object-oriented one. In the example below (updating state variables x and y at each time step), I expected the object-oriented version using Equinox to be slower than the functional version because of eqx.tree_at, but it turned out the wall times were comparable. Is this an expected result? Does eqx.tree_at not incur any additional overhead, or did XLA take care of everything?

import timeit

import jax
from jax import lax
import jax.numpy as jnp
import equinox as eqx
from jaxtyping import Array

def update_y(x):
    return x**2

## OOP implementation

class State(eqx.Module):
    x: Array
    y: Array
    def __init__(self, N_s):
        self.x = jnp.linspace(0, 1, N_s)
        self.y = update_y(self.x)

class Model(eqx.Module):
    def update_state(self, x, dx):
        x = x + dx
        y = update_y(x)
        return x, y

def time_stepping(state, dx):
    new_x, new_y = model.update_state(state.x, dx)
    state = eqx.tree_at(lambda t: t.x, state, new_x)
    state = eqx.tree_at(lambda t: t.y, state, new_y)
    return state, new_y

state = State(1000)
model = Model()
dx_array = jnp.linspace(0.1, 1, 1000)

carry, output = lax.scan(time_stepping, state, dx_array)

time_class = lambda: jax.block_until_ready(lax.scan(time_stepping, state, dx_array))

print(min(timeit.repeat(time_class, number=1, repeat=10)))

## Functional Programming

jax.clear_caches()

N_s = 1000
x = jnp.linspace(0, 1, N_s)
y = update_y(x)
state_tuple = (x, y)

def update_state(x, dx):
    x = x + dx
    y = update_y(x)
    return x, y

def time_stepping(state, dx):
    x, y = state
    new_x, new_y = update_state(x, dx)
    return (new_x, new_y), new_y

dx_array = jnp.linspace(0.1, 1, 1000)

carry, output = lax.scan(time_stepping, state_tuple, dx_array)

time_tuple = lambda: jax.block_until_ready(lax.scan(time_stepping, state_tuple, dx_array))

print(min(timeit.repeat(time_tuple, number=1, repeat=10)))

lockwo commented 11 months ago

A few thoughts. TL;DR: JAX/XLA took care of it.

First, although I'm no programming language design expert, I'm not sure that Equinox objects actually fall under the standard OOP paradigm. Frozen dataclasses are basically functional, since self is immutable.

Second, tree_at is a rather complex function to implement (take a look at the code; I encourage you to do so), but these sorts of operations are often accelerated by XLA: tree_at normally returns an updated copy of the pytree, but if you write x = tree_at(..., x, ...) then it doesn't actually have to copy the underlying arrays, and can get away with relatively fast read/write operations on the pytree.

Third, if you inspect what JAX is actually running, you can see that the two functions (the Equinox one and the non-Equinox one) are identical as far as JAX is concerned (see the screenshot below).

This is not wholly surprising, since the state tuple is a pytree (tuples are just pytrees: https://jax.readthedocs.io/en/latest/pytrees.html) and so is the Equinox module. To JAX, both are just pytrees (which is one of the big selling points of Equinox).
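
For instance (a minimal check, reusing the state and state_tuple values from the snippets above), you can flatten both and see that JAX is handed the same leaves either way:

leaves_module, treedef_module = jax.tree_util.tree_flatten(state)
leaves_tuple, treedef_tuple = jax.tree_util.tree_flatten(state_tuple)

print(treedef_module)  # the State module's structure
print(treedef_tuple)   # the tuple's structure

# Either way the leaves are two arrays of shape (1000,).
assert [l.shape for l in leaves_module] == [l.shape for l in leaves_tuple]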

[Screenshot: the jaxprs for the Equinox version and the tuple version, showing they are identical.]

So I imagine any differences in execution speed are largely fluctuations. Here is the code to inspect the jaxprs:


import timeit

import jax
from jax import lax, make_jaxpr
import jax.numpy as jnp
import equinox as eqx
from jaxtyping import Array

def update_y(x):
    return x**2

## OOP implementation

class State(eqx.Module):
    x: Array
    y: Array
    def __init__(self, N_s):
        self.x = jnp.linspace(0, 1, N_s)
        self.y = update_y(self.x)

class Model(eqx.Module):
    def update_state(self, x, dx):
        x = x + dx
        y = update_y(x)
        return x, y

def time_stepping(state, dx):
    new_x, new_y = model.update_state(state.x, dx)
    state = eqx.tree_at(lambda t: t.x, state, new_x)
    state = eqx.tree_at(lambda t: t.y, state, new_y)
    return state, new_y

state = State(1000)
model = Model()
dx_array = jnp.linspace(0.1, 1, 1000)

carry, output = lax.scan(time_stepping, state, dx_array)

time_class = lambda: jax.block_until_ready(lax.scan(time_stepping, state, dx_array))

print(make_jaxpr(time_stepping)(state, dx_array[0]))

print(min(timeit.repeat(time_class, number=1, repeat=10)))

## Functional Programming

jax.clear_caches()

N_s = 1000
x = jnp.linspace(0, 1, N_s)
y = update_y(x)
state_tuple = (x, y)

def update_state(x, dx):
    x = x + dx
    y = update_y(x)
    return x, y

def time_stepping(state, dx):
    x, y = state
    new_x, new_y = update_state(x, dx)
    return (new_x, new_y), new_y

dx_array = jnp.linspace(0.1, 1, 1000)

carry, output = lax.scan(time_stepping, state_tuple, dx_array)

print(make_jaxpr(time_stepping)(state_tuple, dx_array[0]))

time_tuple = lambda: jax.block_until_ready(lax.scan(time_stepping, state_tuple, dx_array))

print(min(timeit.repeat(time_tuple, number=1, repeat=10)))

patrick-kidger commented 11 months ago

To expand on this. Something like this:

a, b, c = ...
x = [a, b]
new_x = eqx.tree_at(lambda y: y[0], x, c)

is entirely equivalent to this:

a, b, c = ...
x = [a, b]  # unused
new_x = [c, b]

in which you'll note that the only thing we actually create on the last line is new_x. We never actually make a copy of b or c!

To reiterate, eqx.tree_at never actually makes a copy of an array. It just manipulates the containers (pytrees) that the arrays sit inside. You can check this yourself:

# after the first snippet above
assert new_x[0] is c  # exact same object, no copies made

That's the first half of why this is fast -- there are no array copies.
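
The same holds for the State module in the original snippet; here's a minimal check, reusing the state, model, and dx_array names defined there:

new_x, new_y = model.update_state(state.x, dx_array[0])
new_state = eqx.tree_at(lambda t: t.x, state, new_x)

assert new_state.x is new_x    # the replacement leaf is the exact same object
assert new_state.y is state.y  # the untouched leaf is shared, not copied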


Next, this kind of pytree manipulation actually vanishes once things are compiled with XLA. All of this pytree handling is just a way to build a computation graph over arrays; XLA doesn't see the pytrees at all. You can see this in @lockwo's screenshot above, in which the two computations have the same jaxpr. (Jaxprs are JAX's internal representation of the computation being performed.) This is the second half of why this is fast: all pytree manipulations are Python-level operations that vanish under compilation.
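
Here's a minimal sketch of that (self-contained, using only jax and equinox): the jaxpr of a function that rebuilds its pytree with eqx.tree_at contains just the array operation, with no trace of the pytree bookkeeping.

import jax
import jax.numpy as jnp
import equinox as eqx

def step(pair):
    a, b = pair
    new_a = a + 1.0
    # Rebuild the tuple via eqx.tree_at rather than writing (new_a, b) by hand.
    return eqx.tree_at(lambda p: p[0], pair, new_a)

print(jax.make_jaxpr(step)((jnp.ones(3), jnp.ones(3))))  # a single `add`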


Other comments:

1.

I'm not sure if equinox objects actually fall under the standard OOP paradigm. Frozen data classes are basically functional, since the self is immutable.

@lockwo is completely correct here. "OOP" refers specifically to mutating classes, typically via their methods. In contrast, Equinox modules are immutable collections of elements -- which "morally speaking" basically just makes them a pretty syntax for a tuple. :)

2. Note that you're missing a JIT in your snippets above. This is important for good performance with JAX. In practice you're getting away with it because you have a lax.scan, which internally secretly does a sort-of-JIT (that's how it handles the body function), but in general you'll find that your computation is slightly slower than if you place a JIT around your whole computation (including the scan).
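
For example, a minimal sketch reusing the names from the first (Equinox) snippet above -- eqx.filter_jit would work just as well here, since State contains only arrays:

@jax.jit
def run(state, dx_array):
    # JIT the entire computation, scan included.
    return lax.scan(time_stepping, state, dx_array)

time_jit = lambda: jax.block_until_ready(run(state, dx_array))
print(min(timeit.repeat(time_jit, number=1, repeat=10)))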

ToshiyukiBandai commented 11 months ago

Hi lockwo and Patrick,

Thank you both for your detailed explanations. That makes sense. As a user, I was a bit worried about the overhead of Equinox operations, so it might be good to emphasize somewhere in the docs how this level of performance is achieved, as you described. Thank you!