patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Runtime penalty with fields that have expensive inits #252

Open pmelchior opened 1 year ago

pmelchior commented 1 year ago

I was surprised by the runtime of a posterior estimation I was coding. Here is an MWE to reproduce the issue:


import jax.numpy as jnp
import equinox as eqx
import distrax

class Parameter(eqx.Module):
    value: jnp.ndarray
    prior: distrax.Distribution

    def log_prior(self):
        return self.prior.log_prob(self.value)

v = jnp.zeros(10)
mu = jnp.ones(10)
sigma = jnp.ones(10)
p0 = distrax.MultivariateNormalDiag(mu, sigma)
p = Parameter(v, p0)

# test with some data
import optax
from jax import random
key = random.PRNGKey(0)
data = random.normal(key, (10,))

@eqx.filter_value_and_grad
def loss_fn_with_grad(model, data):
    neg_log_like = 0.5 * ((model.value - data)**2).sum()
    return neg_log_like - model.log_prior()

@eqx.filter_jit
def make_step(model, data, opt_state):
    loss, grads = loss_fn_with_grad(model, data)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

learning_rate=1e-1
optim = optax.adam(learning_rate)
opt_state = optim.init(eqx.filter(p, eqx.is_array))

for step in range(100):
    loss, p, opt_state = make_step(p, data, opt_state)
    loss = loss.item()
    print(f"step={step}, loss={loss}")

The runtime for a single call to make_step is about 75 ms on a single core. This drops to 100 µs when the prior is declared static:

class Parameter(eqx.Module):
    value: jnp.ndarray
    prior: distrax.Distribution = eqx.static_field()

    def log_prior(self):
        return self.prior.log_prob(self.value)

The issue, I think, is that in the first case, filtering creates new instances of prior:

params, static = eqx.partition(p, eqx.is_array)
print(p)
print(params)
print(static)

yields

Parameter(
  value=f32[10],
  prior=<distrax._src.distributions.mvn_diag.MultivariateNormalDiag object at 0x13e9ae9d0>
)
Parameter(
  value=f32[10],
  prior=<distrax._src.distributions.mvn_diag.MultivariateNormalDiag object at 0x15b82c400>
)
Parameter(
  value=None,
  prior=<distrax._src.distributions.mvn_diag.MultivariateNormalDiag object at 0x15bd2d370>
)

In the case of prior: distrax.Distribution = eqx.static_field() it's the same instance. I suspect that the cost of initializing the new instances is what's dragging down the performance. This is probably the intended behavior here, but I was surprised by it because in neither case is prior treated as a tree leaf, which I thought was what gets altered on each gradient update.
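For example, rebuilding p with the static_field version of Parameter and repeating the partition above should show the prior object being reused rather than re-created (a quick sketch, not verified output):

params, static = eqx.partition(p, eqx.is_array)
print(params.prior is p.prior)  # expected: True -- the static field lives in the treedef
print(static.prior is p.prior)  # expected: True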

I might be doing this wrong, but if not, it would be good to clarify the use of static_field in the docs, where it's generally discouraged.

patrick-kidger commented 1 year ago

Okay, this is interesting. Thanks for opening this issue!


First of all, let's establish a baseline. On my machine, I get:

%timeit make_step(p, data, opt_state)
30.4 ms ± 3.48 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)  # baseline

Now, the dev branch of Equinox removes the ability to specify a full filter specification for how to dynamically/statically partition the tree. (The reality was that this simply wasn't used much, as the default of dynamically-trace-all-arrays and static-everything-else is what we want nearly all the time.) This simplification gives us, amongst other things, a noticeable speed improvement:

11.6 ms ± 426 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)  # dev branch

So, hurrah, that's an improvement that's worth something!

But it's still a fair bit slower than what you were reporting when using static_field. Once again, baseline:

171 µs ± 30.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)  # static_field (+ dev branch)

Let's see if we can track down where this is coming from. It's not the flattening of prior:

%timeit jax.tree_util.tree_flatten(p.prior)
1.51 µs ± 33.6 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

But the unflattening does seem very costly:

flat, treedef = jax.tree_util.tree_flatten(p.prior)
%timeit jax.tree_util.tree_unflatten(treedef, flat)
1.7 ms ± 79.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Adding a quick print statement to Distrax's unflattening rule Jittable.tree_unflatten, we see that it prints out five times for a single (compiled) run. Now 5 × 1.7 ms isn't far off 11.6 ms, so it seems that Distrax's unflattening rule accounts for pretty much the entire performance cost.

This looks like a Distrax bug: I tried writing a quick analogue in Equinox and it unflattens in:

2.56 µs ± 67.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
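(For reference, a quick analogue along these lines might look something like the sketch below; the class name and fields are illustrative, not the exact code used for the timing above:)

import jax.numpy as jnp
import equinox as eqx

class MvnDiag(eqx.Module):
    # Illustrative stand-in for a diagonal-covariance multivariate normal.
    loc: jnp.ndarray
    scale_diag: jnp.ndarray

    def log_prob(self, value):
        # Sum of independent 1D Gaussian log densities over the last axis.
        z = (value - self.loc) / self.scale_diag
        return jnp.sum(
            -0.5 * z**2 - jnp.log(self.scale_diag) - 0.5 * jnp.log(2 * jnp.pi),
            axis=-1,
        )

Because an eqx.Module flattens/unflattens like a plain dataclass-style PyTree, there's essentially no per-unflatten setup work, which is why it comes in at microseconds rather than milliseconds.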

What are the paths forward?

Option 1

Fix Distrax.

Unfortunately, the library seems semi-unmaintained. (In fact, in the process of debugging this, I came across this Distrax issue, which seems to indicate that the master version of Distrax has several other issues beyond the one discussed here.)

So I'm honestly not sure if Distrax is a great choice of library to be using. Maybe switch away from it, or write your own competing library instead.

Option 2

It's definitely a hack, but just marking the Distrax distribution as a static_field seems to work. This is essentially equivalent to capturing the static'd object via closure: for example, you can't differentiate through it.
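(To make the closure point concrete, here's a rough sketch of the equivalent formulation, in which the prior never lives on the module at all; names mirror the example above, and Parameter would then only hold value:)

# The prior is captured by closure rather than stored on the Module.
prior = distrax.MultivariateNormalDiag(mu, sigma)

@eqx.filter_value_and_grad
def loss_fn_with_grad(model, data):
    neg_log_like = 0.5 * ((model.value - data) ** 2).sum()
    # prior is closed over: it is never flattened/unflattened by filter_jit,
    # and gradients cannot flow into it.
    return neg_log_like - prior.log_prob(model.value)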

Incidentally, this only actually works because of a different bug in Distrax. (Those JAX arrays you're passing into it aren't hashable, and static data needs to be hashable.)
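(A one-liner to see the hashability point for yourself:)

import jax.numpy as jnp

hash(jnp.ones(10))  # raises TypeError -- JAX arrays are not hashable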

But -- as long as that works for you, then go for it.

Option 3

Specifically, when running a repeated training loop, it's possible to "cancel out" the flattening/unflattening that's happening in the loop (i.e. the unflattening at the end of filter_jit, followed by the flattening at the start of the next filter_jit call).

In this case model and opt_state are nontrivial PyTrees, so you can do:

import jax.tree_util as jtu

...

flat_p, treedef_model = jtu.tree_flatten(p)
flat_opt_state, treedef_opt_state = jtu.tree_flatten(opt_state)

@eqx.filter_jit
def make_step(flat_model, data, flat_opt_state):
    model = jtu.tree_unflatten(treedef_model, flat_model)
    opt_state = jtu.tree_unflatten(treedef_opt_state, flat_opt_state)
    ... # as before
    flat_model = jtu.tree_leaves(model)
    flat_data = jtu.tree_leaves(data)
    return loss, flat_model, flat_data

for step in range(100):
    loss, flat_p, flat_opt_state = make_step(flat_p, data, flat_opt_state)

This is a generic trick that's useful for every training loop. (Library idea: perhaps it'd be worth writing some kind of PyTorch-Lightning-style training loop library that does this kind of thing automatically.)

This has quite a dramatic effect for us, since unflattening is so costly in our case. With this trick in place, we get down to a tidy:

%timeit make_step(flat_p, data, flat_opt_state)
74.2 µs ± 12.7 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Hopefully that helps!

pmelchior commented 1 year ago

Awesome! Thanks for the suggestions. Option 3 looks like the way to go.

Btw, I'm not at all sold on Distrax; it was just the first thing I tried for a toy model. In practice I'll have my own distribution models; the problem above came up during testing.

adam-hartshorne commented 1 year ago

In your Option 3 example, the make_step function returns loss, flat_model, **flat_data**, but in the step loop the variables are loss, flat_p, **flat_opt_state**.

I presume flat_data = jtu.tree_leaves(data) should have read flat_opt_state = jtu.tree_leaves(opt_state).

patrick-kidger commented 1 year ago

Good catch! You're correct.
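For anyone copying the Option 3 snippet, the corrected tail of make_step would read something like:

    ... # as before
    flat_model = jtu.tree_leaves(model)
    flat_opt_state = jtu.tree_leaves(opt_state)  # flatten the optimizer state, not the data
    return loss, flat_model, flat_opt_state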

lockwo commented 5 months ago

It's still in very preliminary stages, but something you might be interested in is distreqx: https://github.com/lockwo/distreqx. I can't get an exact benchmark, since when I run the Distrax code on my machine I see a very long ValueError: Mismatch custom node data error. But I can say that

import jax.numpy as jnp
import equinox as eqx
from distreqx import distributions

class Parameter(eqx.Module):
    value: jnp.ndarray
    prior: distributions.AbstractDistribution

    def log_prior(self):
        return self.prior.log_prob(self.value)

v = jnp.zeros(10)
mu = jnp.ones(10)
sigma = jnp.ones(10)
p0 = distributions.MultivariateNormalDiag(mu, sigma)
p = Parameter(v, p0)

# test with some data
import optax
from jax import random
key = random.PRNGKey(0)
data = random.normal(key, (10,))

@eqx.filter_value_and_grad
def loss_fn_with_grad(model, data):
    neg_log_like = 0.5 * ((model.value - data)**2).sum()
    return neg_log_like - model.log_prior()

@eqx.filter_jit
def make_step(model, data, opt_state):
    loss, grads = loss_fn_with_grad(model, data)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

learning_rate=1e-1
optim = optax.adam(learning_rate)
opt_state = optim.init(eqx.filter(p, eqx.is_array))

_ = make_step(p, data, opt_state)  # call once so make_step is compiled before timing
%%timeit
_ = make_step(p, data, opt_state)

yields a time much faster than the ~30 ms reported above: 727 µs ± 17.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each).

And if you mark the prior as static, you see 102 µs ± 950 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each).