tims457 opened this issue 2 years ago

Will diffrax.diffeqsolve work inside a Flax linen Module? How would you set up the initialization to use Flax inside of ODETerm instead of Equinox?
I'm not a Flax user, so take this with a pinch of salt. But probably something like the following:

variables = model.init(...)

def vector_field(t, y, args):
    return model.apply(args, y)

diffeqsolve(ODETerm(vector_field), ..., args=variables)
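Fleshed out into something runnable (a sketch; using nn.Dense as the vector field is just for illustration):

import diffrax
import jax
import jax.numpy as jnp
from flax import linen as nn

model = nn.Dense(4)  # illustrative vector field network
y0 = jnp.ones(4)
variables = model.init(jax.random.PRNGKey(0), y0)  # initialise outside of diffeqsolve

def vector_field(t, y, args):
    # args carries the Flax variables, so nothing from init leaks into the solve.
    return model.apply(args, y)

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field), diffrax.Dopri5(),
    t0=0, t1=1, dt0=0.1, y0=y0, args=variables,
)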
Thanks. I'll give it a go.
Hmm, I want to do something like the following:
import diffrax
import jax
import jax.numpy as jnp
from flax import linen as nn
class NeuralODE(nn.Module):
    derivative_net: nn.Module

    def __call__(self, coords):
        def f(t, y, args):
            return self.derivative_net(y)

        term = diffrax.ODETerm(f)
        solver = diffrax.Dopri5()
        solution = diffrax.diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=coords)
        return solution.ys
coords = jnp.ones((1, 4))
model = NeuralODE(derivative_net=nn.Dense(4))
rng = jax.random.PRNGKey(0)
params = jax.jit(model.init)(rng, coords)
Yet, this gives me:
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (4,) and dtype float32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was diffeqsolve at ./.venv/lib/python3.10/site-packages/equinox/jit.py:25 traced for xla_call.
------------------------------
The leaked intermediate value was created on line ./.venv/lib/python3.10/site-packages/flax/core/scope.py:767 (param).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
...
./.venv/lib/python3.10/site-packages/flax/core/scope.py:767 (param)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
Detail: Can't lift sublevels 2 to 1
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
Any ideas @patrick-kidger?
I'm afraid not. This looks like something to do with Flax, which I'm not familiar enough with to help debug.
That said, I'm sure it's possible. Both Equinox and Diffrax operate at "the same level as" normal JAX. (In contrast, Flax is a wrapper around JAX.)
Thanks for the help. It's at moments like these that I wish I was still at Google :) I wonder if diffrax is calling jit inside diffeqsolve. That would explain the error. Is there a way to disable that?
Otherwise, is there a way to use Equinox and Flax together? Do you have any examples?
Ah, I figured out a (slightly hacky) way to do this:
import flax  # for flax.struct

class NeuralODE(flax.struct.PyTreeNode):
    """A simple neural ODE."""

    encoder: nn.Module
    derivative_net: nn.Module
    decoder: nn.Module

    def init(self, rng, coords):
        rng, encoder_rng, derivative_net_rng, decoder_rng = jax.random.split(rng, 4)
        coords, encoder_params = self.encoder.init_with_output(encoder_rng, coords)
        coords, derivative_net_params = self.derivative_net.init_with_output(derivative_net_rng, coords)
        coords, decoder_params = self.decoder.init_with_output(decoder_rng, coords)
        return {
            "encoder": encoder_params,
            "derivative_net": derivative_net_params,
            "decoder": decoder_params,
        }

    def apply(self, params, coords):
        coords = self.encoder.apply(params["encoder"], coords)

        def f(t, y, args):
            return self.derivative_net.apply(params["derivative_net"], y)

        term = diffrax.ODETerm(f)
        solver = diffrax.Euler()
        solution = diffrax.diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=coords)
        coords = solution.ys
        coords = self.decoder.apply(params["decoder"], coords)
        return coords

rng = jax.random.PRNGKey(0)
coords = jnp.ones((1, 4))
model = NeuralODE(
    encoder=nn.Dense(10),
    derivative_net=nn.Dense(10),
    decoder=nn.Dense(4),
)
params = jax.jit(model.init)(rng, coords)
Then you can simply use this like any other nn.Module:
@jax.jit
def compute_loss(params, coords, true_coords):
    preds = model.apply(params, coords)
    return jnp.abs(preds - true_coords).sum()

grads = jax.grad(compute_loss)(params, coords, jnp.zeros_like(coords))
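And for training, a sketch of an update step on top of that (optax isn't part of the snippet above; it's just an assumption for illustration):

import optax

optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, coords, true_coords):
    # Standard gradient step; any optax optimizer works the same way here.
    loss, grads = jax.value_and_grad(compute_loss)(params, coords, true_coords)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss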
This just uses flax.struct.PyTreeNode instead of eqx.Module. I didn't want to mix both of them in my codebase. Thanks a lot for the help!
Hurrah! I'm glad you figured this out.
@patrick-kidger I'm using Haiku and also ran into the leaked-tracer issue.
I was almost clueless for about an hour, until I noticed that, compared with my other programs, this "buggy" program uses the Haiku model (a simple MLP) only inside the ODETerm, so the Haiku model's initialization happens inside Diffrax's frames.
I worked around it by calling the model once with compatible fake data before calling Diffrax, so the model is already initialized by the time Diffrax runs.
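Roughly, the workaround looks like this (a sketch of my setup; the MLP shape is made up):

import haiku as hk

def forward(y0):
    mlp = hk.nets.MLP([32, y0.shape[-1]])
    _ = mlp(y0)  # dummy call: parameters get created here, outside Diffrax's internal jit

    def vector_field(t, y, args):
        # Same module instance, so Haiku reuses the parameters created above.
        return mlp(y)

    sol = diffrax.diffeqsolve(diffrax.ODETerm(vector_field), diffrax.Euler(),
                              t0=0, t1=1, dt0=0.1, y0=y0)
    return sol.ys

model = hk.without_apply_rng(hk.transform(forward))
params = model.init(jax.random.PRNGKey(0), jnp.ones(4))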
I understand that the Haiku/Flax way of bridging JAX-style pure functions with PyTorch-style modules involves a certain degree of "dark" magic, but I don't think it should be this easy to trigger such a frightening leaked-tracer exception.
Could you suggest the root cause of this exception? If it's infeasible to prevent such a leak, would it at least be possible to provide a more reasonable error or warning?
So Haiku (like Flax) was implemented as a wrapper around JAX, and by and large isn't compatible with other libraries in the JAX ecosystem. That's really the root cause -- Haiku assumes that it's only being used in particular ways.
My top recommendation is just to use Equinox instead. This provides a PyTorch-style module without the "dark magic".
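For example (a quick sketch; sizes are illustrative):

import equinox as eqx

mlp = eqx.nn.MLP(in_size=4, out_size=4, width_size=32, depth=2,
                 key=jax.random.PRNGKey(0))

def vector_field(t, y, args):
    # The Equinox module is just a pytree of arrays; no init/apply split needed.
    return mlp(y)

sol = diffrax.diffeqsolve(diffrax.ODETerm(vector_field), diffrax.Dopri5(),
                          t0=0, t1=1, dt0=0.1, y0=jnp.ones(4))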
If you really want to use Haiku, then probably the best thing to do is to pass your MLP through hk.transform before using another library. This should transform the Haiku DSL into "normal JAX".
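Something like this (a sketch; hk.nets.MLP is just for illustration):

import haiku as hk

mlp = hk.without_apply_rng(hk.transform(lambda y: hk.nets.MLP([32, 4])(y)))
params = mlp.init(jax.random.PRNGKey(0), jnp.ones(4))  # initialise *outside* Diffrax

def vector_field(t, y, args):
    # After hk.transform, apply is a pure function of (params, y): "normal JAX".
    return mlp.apply(args, y)

sol = diffrax.diffeqsolve(diffrax.ODETerm(vector_field), diffrax.Dopri5(),
                          t0=0, t1=1, dt0=0.1, y0=jnp.ones(4), args=params)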