patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Compatibility with Flax #115

Open tims457 opened 2 years ago

tims457 commented 2 years ago

Will diffrax.diffeqsolve work inside a Flax linen Module? How would you set up the initialization to use Flax inside of ODETerm instead of Equinox?

patrick-kidger commented 2 years ago

I'm not a Flax user, so take this with a pinch of salt. But probably something like the following.

variables = model.init(...)

def vector_field(t, y, args):
    return model.apply(args, y)

diffeqsolve(ODETerm(vector_field), ..., args=variables)
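
For reference, a more complete version of this pattern might look like the following (a minimal sketch, assuming an nn.Dense vector field; the names are illustrative rather than taken from the thread). The key point is that model.init runs outside diffeqsolve, and the resulting variables are threaded in through args:

import diffrax
import jax
import jax.numpy as jnp
from flax import linen as nn

model = nn.Dense(4)
y0 = jnp.ones(4)

# Initialise the Flax variables *outside* the solve.
variables = model.init(jax.random.PRNGKey(0), y0)

def vector_field(t, y, args):
    # `args` carries the Flax variables; `apply` is a pure function of them.
    return model.apply(args, y)

term = diffrax.ODETerm(vector_field)
solution = diffrax.diffeqsolve(term, diffrax.Dopri5(), t0=0.0, t1=1.0, dt0=0.1,
                               y0=y0, args=variables)
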
tims457 commented 2 years ago

Thanks. I'll give it a go.

ameya98 commented 2 years ago

Hmm, I want to do something like the following:

import diffrax
import jax
import jax.numpy as jnp
from flax import linen as nn

class NeuralODE(nn.Module):
    derivative_net: nn.Module

    def __call__(self, coords):
        def f(t, y, args):
            return self.derivative_net(y)

        term = diffrax.ODETerm(f)
        solver = diffrax.Dopri5()
        solution = diffrax.diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=coords)
        return solution.ys

coords = jnp.ones((1, 4))
model = NeuralODE(derivative_net=nn.Dense(4))
rng = jax.random.PRNGKey(0)
params = jax.jit(model.init)(rng, coords)

Yet, this gives me:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (4,) and dtype float32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was diffeqsolve at ./.venv/lib/python3.10/site-packages/equinox/jit.py:25 traced for xla_call.
------------------------------
The leaked intermediate value was created on line ./.venv/lib/python3.10/site-packages/flax/core/scope.py:767 (param). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
...
./.venv/lib/python3.10/site-packages/flax/core/scope.py:767 (param)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
Detail: Can't lift sublevels 2 to 1
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

Any ideas @patrick-kidger?

patrick-kidger commented 2 years ago

I'm afraid not. This looks like something to do with Flax, which I'm not familiar enough with to help debug.

That said, I'm sure it's possible. Both Equinox and Diffrax operate at "the same level as" normal JAX. (In contrast, Flax is a wrapper around JAX.)
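
For comparison, a minimal sketch of the same kind of model written directly with Equinox (layer sizes and names here are illustrative, not from this thread). An eqx.Module is itself a pytree, so it can simply be closed over by the vector field and passed through eqx.filter_jit / eqx.filter_grad like any other JAX value:

import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp

class NeuralODE(eqx.Module):
    derivative_net: eqx.nn.MLP

    def __call__(self, y0):
        def f(t, y, args):
            # The module's parameters are ordinary pytree leaves; no init/apply split.
            return self.derivative_net(y)

        term = diffrax.ODETerm(f)
        solution = diffrax.diffeqsolve(term, diffrax.Dopri5(), t0=0, t1=1, dt0=0.1, y0=y0)
        return solution.ys

model = NeuralODE(eqx.nn.MLP(in_size=4, out_size=4, width_size=32, depth=2,
                             key=jax.random.PRNGKey(0)))
ys = model(jnp.ones(4))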

ameya98 commented 2 years ago

Thanks for the help. It's at moments like these that I wish I were still at Google :) I wonder if diffrax is calling jit inside diffeqsolve. That would explain the error. Is there a way to disable that?

ameya98 commented 2 years ago

Otherwise, is there a way to use Equinox and Flax together? Do you have any examples?

ameya98 commented 2 years ago

Ah, I figured out a (slightly hacky) way to do this:

class NeuralODE(flax.struct.PyTreeNode):
    """A simple neural ODE."""

    encoder: nn.Module
    derivative_net: nn.Module
    decoder: nn.Module

    def init(self, rng, coords):
        rng, encoder_rng, derivative_net_rng, decoder_rng = jax.random.split(rng, 4)
        coords, encoder_params = self.encoder.init_with_output(encoder_rng, coords)
        coords, derivative_net_params = self.derivative_net.init_with_output(derivative_net_rng, coords)
        coords, decoder_params = self.decoder.init_with_output(decoder_rng, coords)

        return {
            "encoder": encoder_params,
            "derivative_net": derivative_net_params,
            "decoder": decoder_params
        }

    def apply(self, params, coords):
        coords = self.encoder.apply(params["encoder"], coords)

        def f(t, y, args):
            return self.derivative_net.apply(params["derivative_net"], y)

        term = diffrax.ODETerm(f)
        solver = diffrax.Euler()
        solution = diffrax.diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=coords)
        coords = solution.ys
        coords = self.decoder.apply(params["decoder"], coords)
        return coords

rng = jax.random.PRNGKey(0)
coords = jnp.ones((1, 4))

model = NeuralODE(
    encoder=nn.Dense(10),
    derivative_net=nn.Dense(10),
    decoder=nn.Dense(4))
params = jax.jit(model.init)(rng, coords)

Then you can simply use this like any other nn.Module:

@jax.jit
def compute_loss(params, coords, true_coords):
    preds = model.apply(params, coords)
    return jnp.abs(preds - true_coords).sum()

grads = jax.grad(compute_loss)(params, coords, jnp.zeros_like(coords))

This just uses flax.struct.PyTreeNode instead of eqx.Module. I didn't want to mix both of them in my codebase. Thanks a lot for the help!

patrick-kidger commented 2 years ago

Hurrah! I'm glad you figured this out.

jjyyxx commented 1 year ago

@patrick-kidger I'm using Haiku and also ran into the leaked tracer issue.

I was almost clueless for about an hour, until I noticed that, compared with my other programs, this "buggy" program uses the Haiku model (a simple MLP) only in the ODETerm, so the Haiku model initialization happens inside the diffrax frames.

I worked around it by calling the model once with fake (shape-compatible) data before calling diffrax, so the model is already initialized by the time diffrax runs.

I understand that the Haiku/Flax way of bridging JAX-style pure functions with PyTorch-style modules involves a certain degree of "dark" magic, but I wouldn't expect such simple usage to trigger a frightening leaked-tracer exception.

Could you suggest the root cause of this exception? If it's infeasible to prevent such a leak, would it be possible to at least provide a more reasonable error or warning?

patrick-kidger commented 1 year ago

So Haiku (like Flax) was implemented as a wrapper around JAX, and by and large isn't compatible with other libraries in the JAX ecosystem. That's really the root cause -- Haiku assumes that it's only being used in particular ways.

My top recommendation is just to use Equinox instead. This provides a PyTorch-style module without the "dark magic".

If you really want to use Haiku, then probably the best thing to do is to pass your MLP through hk.transform before using another library. This should transform the Haiku DSL into "normal JAX".
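
For example, something along these lines (a minimal sketch, not from this thread; the MLP sizes are illustrative):

import diffrax
import haiku as hk
import jax
import jax.numpy as jnp

def mlp(y):
    return hk.nets.MLP([32, 32, 4])(y)

# hk.transform turns the Haiku DSL into a pair of pure (init, apply) functions.
model = hk.without_apply_rng(hk.transform(mlp))

y0 = jnp.ones(4)

# Initialise *before* calling diffrax, so no Haiku parameters are created inside the solve.
params = model.init(jax.random.PRNGKey(0), y0)

def vector_field(t, y, args):
    # `args` carries the Haiku params; `apply` is pure, so no tracers can leak.
    return model.apply(args, y)

term = diffrax.ODETerm(vector_field)
solution = diffrax.diffeqsolve(term, diffrax.Dopri5(), t0=0.0, t1=1.0, dt0=0.1,
                               y0=y0, args=params)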