ciupakabra opened this issue 2 years ago
How does it change if you switch to `control=UnsafeBrownianPath`? What about if you switch to `diffeqsolve(..., stepsize_controller=ConstantStepSize(compile_steps=True))`?
Putting both of those together should be essentially identical to just writing out the solver in plain JAX.
EDIT: actually, I think `UnsafeBrownianPath` doesn't support backpropagation here, which it probably should. Let me have a look at this.
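(For concreteness, a rough sketch of those two suggestions combined -- forward solve only, given the EDIT above. The vector field, dimension, and step size are placeholder assumptions of mine, not from this thread:)

```python
# Sketch only: UnsafeBrownianPath + ConstantStepSize(compile_steps=True).
# The vector field, dimension, and step size are placeholder assumptions.
import diffrax
import jax.numpy as jnp
import jax.random as jrandom

dim = 4
bm = diffrax.UnsafeBrownianPath(shape=(dim,), key=jrandom.PRNGKey(0))
diffusion_term = diffrax.ControlTerm(lambda t, y, args: 0.1 * jnp.eye(dim), bm)

sol = diffrax.diffeqsolve(
    diffusion_term,
    solver=diffrax.Euler(),
    t0=0.0,
    t1=1.0,
    dt0=0.01,
    y0=jnp.zeros(dim),
    stepsize_controller=diffrax.ConstantStepSize(compile_steps=True),
)
```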
Ech, looks like this runs afoul of #176 as well. Which version of JAX are you using? For the purposes of debugging this I'll downgrade. (And #176 will be fixed just as soon as https://github.com/google/jax/pull/13062 lands.)
Hmm, my versions are `jax 0.3.14` and `jaxlib 0.3.14`, which seem to fix #176 but still give the behaviour described above. I've tried different versions of diffrax but not jax. Any particular jax version you recommend I try?
EDIT: with `stepsize_controller=ConstantStepSize(compile_steps=True)` the speeds improve, but the difference is still quite big: integrating `terms` and `diffusion_term` takes 26s, while `drift_term` takes 3s.
2nd EDIT: disabling this check and using `UnsafeBrownianPath` for the control has the same problem.
Okay, I think I've tracked this down. It's because without the diffusion term, your computation is actually unbatched: your only batched input is `key`, but this is unused. So JAX optimises the computation to only run the diffeq for a single batch element, then just broadcasts it out at the end. But if you have the diffusion term, then you're actually solving 32 different equations at the same time, and the same optimisation can't be made.
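(To illustrate with a standalone sketch of mine -- nothing Diffrax-specific -- note how the batched-but-unused `key` below lets the compiler do the work once and broadcast:)

```python
# Sketch: if the only vmap'd input is unused, the traced computation doesn't
# depend on the batch axis at all, so it is effectively run once and broadcast.
import jax
import jax.numpy as jnp
import jax.random as jrandom

def expensive(y0, key):
    del key  # batched input, deliberately unused
    y = y0
    for _ in range(100):
        y = jnp.sin(y) + 0.1 * y
    return y

keys = jrandom.split(jrandom.PRNGKey(0), 32)
y0 = jnp.zeros(4)
out = jax.jit(jax.vmap(lambda k: expensive(y0, k)))(keys)
# Timing this scales like a single solve, not 32; actually using `key` inside
# `expensive` (e.g. to add noise) restores the full batched cost.
```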
How come JAX manages to optimize this for a straightforward implementation of the solver? Changing the `sample` function to the below runs in 5s. (This is with the same versions of `jax`, `jaxlib` and `diffrax` as before -- upgrading all three to the most recent versions gives quite a big performance boost in all scenarios (any ideas why?) but the differences are still similar: the diffrax ODE term takes 2s, the diffrax SDE with `VirtualBrownianTree` takes 15s with `compile_steps=True` / 24s without, the diffrax SDE with `UnsafeBrownianPath` takes 10s with `compile_steps=True` / 14s without, and the implementation below still takes 5s.)
```python
import jax
import jax.numpy as jnp
import jax.random as jrandom


def euler_maruyama(drift, diff, num_steps, t0, t1, y0, key):
    # Fixed-step Euler-Maruyama over [t0, t1], sampling the Brownian
    # increments directly from per-step PRNG keys.
    ts = jnp.linspace(t0, t1, num_steps + 1)[:-1]
    keys = jrandom.split(key, num_steps)
    dt = (t1 - t0) / num_steps

    def _body_fun(y, args):
        t, key = args
        w = jrandom.normal(key, y.shape)
        _y = y + dt * drift(t, y, None) + jnp.sqrt(dt) * diff(t, y, None) * w
        return _y, _y

    _, path = jax.lax.scan(_body_fun, y0, (ts, keys))
    path = jnp.concatenate([y0[None, ...], path], axis=0)
    return path


def sample(model, num_steps, dim, key):
    drift, diff = model

    def f(t, y, args):
        return drift(jnp.concatenate([t[None], y]))

    def g(t, y, args):
        return diff(jnp.concatenate([t[None], y])) ** 2

    y0 = jnp.zeros(dim)
    path = euler_maruyama(f, g, num_steps, 0.0, 1.0, y0, key)
    return path
```
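(Presumably this is then vmapped over a batch of keys, along these lines -- `model`, the step count, and the dimension here are my assumptions:)

```python
keys = jrandom.split(jrandom.PRNGKey(0), 32)
paths = jax.vmap(lambda k: sample(model, 1000, 4, k))(keys)
```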
I was assuming the problem was that the sampling in `VirtualBrownianTree` is somewhat complicated, but the same happens with `UnsafeBrownianPath` as well (after disabling the check as in my previous comment) -- so why is JAX unable to optimize it as it does the above implementation?
> Okay, I think I've tracked this down.

Did you do this by just inspecting jaxprs?
Upgrading giving a performance boost -- this is because I've been working to improve the efficiency of Diffrax :)
(And there's more stuff coming in just over the horizon: leave a comment over on https://github.com/google/jax/pull/13184 if you want.)
Yeah, the complexity of sampling the Brownian motion was my first thought as well. (Stuff like key-splitting and key-folding-in is also not that cheap, and was my second thought.) JAX manages to optimise your implementation for the same reason it's able to optimise the Diffrax version: you are only `vmap`ing with respect to `key`. If you remove the diffusion term then you're not using `key` at all, so the computation is no longer `vmap`'d. You go from solving a batch of diffeqs to just solving a single diffeq.
In terms of how I tracked this down: nope, no jaxprs. They're sadly not that helpful for debugging anything Diffrax-related. Differential equation solvers are large and complicated enough that they end up being too large to really interpret. (EDIT: although this should be simplified a lot once https://github.com/google/jax/pull/13062 lands.)
In this case I tracked things down by substituting `VirtualBrownianTree` for a simple

```python
class Control:
    def evaluate(self, t0, t1, left=True):
        return t1 - t0
```

(which produced the fast version) and then bisecting the differences between `Control` and `VirtualBrownianTree` until I tracked down what the issue was.
If you want to induce a batch dependence without any of this, try doing `y0 = y0 + 0.0001 * key[0]`.
Thanks for the quick response!
Are you sure this difference comes only from `vmap`'ing? The custom implementation above does depend on the `key` (since we're solving the full SDE and sampling noise), and doing `y0 = y0 + 0.0001 * key[0]` does not change the speed. However, it's still faster than using diffrax with either `UnsafeBrownianPath` or `VirtualBrownianTree`.

I agree that when using diffrax the difference between integrating the `drift_term` and both terms is partly because of the vmap. But even with `y0 = y0 + 0.0001 * key[0]`, integrating `drift_term` is twice as fast as integrating both terms with `VirtualBrownianTree`.
> But even with `y0 = y0 + 0.0001 * key[0]` integrating `drift_term` is twice as fast as integrating both terms with `VirtualBrownianTree`
I realized that this could simply be because we are not evaluating the `diff` network. When we change `g` to
```python
def g(t, y, args):
    return 1.0
```
there's little difference when using `VirtualBrownianTree`. But with this diffusion, integrating both terms with `UnsafeBrownianPath` is the same as integrating just the `drift_term` but with `y0 = y0 + 0.0001 * key[0]`. So it seems the performance difference mainly comes down to two things: `VirtualBrownianTree` is expensive, and integrating `drift_term` alone did not `vmap` because there was no dependence on the `key`.
Perhaps it then makes sense to allow differentiating through `UnsafeBrownianPath` in certain cases? I'm happy to submit a PR for this, but I'm not sure why it's disabled in the first place -- is it because `force_bitcast_convert_type` is not very stable?
Regarding the difference between the custom implementation above and diffrax with `UnsafeBrownianPath` -- could it be that `jax.lax.scan` is simply faster than `bounded_while_loop`? Would it make sense to add something like `exact_num_steps` alongside `max_steps`, so that a scan is used when the number of steps is known exactly (as it is in this case)?
So yeah, there are some definite differences here due to whether we're evaluating just `drift`, just `diff`, or both. In particular these two networks are also of different sizes, and I noticed that this also produced a measurable speed difference when doing drift-vs-diffusion comparisons (controlling for everything else).

FWIW I didn't notice `VirtualBrownianTree` adding very much in the way of overhead (~5% relative to a dummy control, once again with every other source of variation fixed), but perhaps YMMV.
The difference between your implementation and `bounded_while_loop`: indeed, the extra complexity of `bounded_while_loop` (which is reduced but not completely eliminated when using `compile_steps=True`) is probably responsible for a fair bit of overhead here. This difference should be eliminated once https://github.com/google/jax/pull/13062 and https://github.com/google/jax/pull/13184 land -- and even better, these should eliminate the overhead even when the number of steps isn't known exactly! The best thing you can do is to leave a comment over on those, +1-ing them, to let the JAX maintainers know this is a priority for you.
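(For intuition, here's a minimal sketch of mine -- not Diffrax's actual `bounded_while_loop` -- of why emulating a while loop with a fixed-length construct costs more than a plain `lax.scan`: every one of `max_steps` iterations runs, with per-step masking to decide whether to keep the update:)

```python
# Sketch: a bounded "while loop" built from a fixed-length scan. All
# max_steps iterations execute; a mask selects whether each update is kept.
# This per-step masking is (part of) the overhead a plain scan avoids.
import jax
import jax.numpy as jnp

def bounded_while(cond_fun, body_fun, init_val, max_steps):
    def step(val, _):
        keep_going = cond_fun(val)
        new_val = body_fun(val)
        # Keep the updated state only while the condition still holds.
        val = jax.tree_util.tree_map(
            lambda new, old: jnp.where(keep_going, new, old), new_val, val
        )
        return val, None

    final_val, _ = jax.lax.scan(step, init_val, None, length=max_steps)
    return final_val

# E.g. keep doubling until we exceed 100, with at most 20 steps:
print(bounded_while(lambda x: x < 100.0, lambda x: 2.0 * x, 1.0, 20))
```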
Differentiating through `UnsafeBrownianPath`: this is disabled because of the checkpointing happening inside `bounded_while_loop` (a necessary part of its magic). The reconstruction from the checkpoints may not be perfect down to the least significant floating-point bit -- e.g. GPU convolutions are nondeterministic -- and then the use of the floating-point number as a PRNG key would result in arbitrarily large changes afterwards. Once again this is an oddity that should vanish once the above PRs land, and we will be able to differentiate with `UnsafeBrownianPath`.
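(A sketch of mine illustrating the failure mode -- this is not the actual Diffrax internals: if a checkpointed time is reconstructed even one bit differently, any randomness derived from it changes completely:)

```python
# Sketch: using a floating-point time as PRNG material is fragile. A one-ULP
# change in t produces a completely unrelated sample.
import jax
import jax.numpy as jnp
import jax.random as jrandom

def noise_at(t, key):
    # Bitcast the float time to an integer and fold it into the key, roughly
    # what sampling "Brownian noise as a function of t" requires.
    t_bits = jax.lax.bitcast_convert_type(t, jnp.int32)
    return jrandom.normal(jrandom.fold_in(key, t_bits))

key = jrandom.PRNGKey(0)
t = jnp.float32(0.5)
t_nudged = jnp.nextafter(t, jnp.float32(1.0))  # differs in the last bit only
print(noise_at(t, key), noise_at(t_nudged, key))  # entirely different samples
```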
Got it, makes sense! Will leave this open until those PRs roll out.
Hi, I was a bit surprised by the difference in speed between backpropagating through ODEs vs SDEs, but couldn't find any discussion in the documentation about the time complexity of either, and with a quick look I couldn't find any issues addressing this. In particular, consider this piece of code, where we vary the different terms being integrated:
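(The original snippet was not preserved in this extraction; the following is a hedged reconstruction consistent with the rest of the thread -- a batch of 32 keys, Euler over [0, 1], a diffusion that just scales Brownian motion -- not the exact code:)

```python
# Hedged reconstruction, not the original snippet: batch of 32 keys, Euler
# over [0, 1], scalar-times-BM diffusion, varying which terms are integrated.
import diffrax
import jax
import jax.numpy as jnp
import jax.random as jrandom

dim, num_steps, batch_size = 4, 1000, 32

drift_term = diffrax.ODETerm(lambda t, y, args: -y)

def sample(key):
    bm = diffrax.VirtualBrownianTree(
        t0=0.0, t1=1.0, tol=1e-3, shape=(dim,), key=key
    )
    diffusion_term = diffrax.ControlTerm(lambda t, y, args: 0.1 * jnp.eye(dim), bm)
    terms = diffrax.MultiTerm(drift_term, diffusion_term)  # the line being
    # varied: pass drift_term or diffusion_term alone instead.
    sol = diffrax.diffeqsolve(
        terms,
        solver=diffrax.Euler(),
        t0=0.0,
        t1=1.0,
        dt0=1.0 / num_steps,
        y0=jnp.zeros(dim),
    )
    return sol.ys

keys = jrandom.split(jrandom.PRNGKey(0), batch_size)
paths = jax.vmap(sample)(keys)
# (The timings quoted below presumably also include backpropagating through
# this, e.g. via jax.grad of some scalar loss of `paths`.)
```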
On my machine locally (MacBook M1), integrating `terms` or `diffusion_term` takes around 41s, and integrating `drift_term` takes around 5s. What is the reason for this difference? Am I doing something wrong here? Note that the computation in the diffusion term is simply multiplying BM by a scalar. Is `VirtualBrownianTree` the slow part here? I suspect implementing the Euler solver in plain JAX would give a faster solution -- would that be wrong? Maybe it's worth adding some documentation about this.