patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Can't use Equinox inside `term` #446

Open pascal-mueller opened 5 months ago

pascal-mueller commented 5 months ago

I have this code solving an ODE (a damped harmonic oscillator with an external force). If I set the force inside `equations` to a numerical value, everything is fine, but if I try to replace it with a neural network, I get:


% python osc.py
Traceback (most recent call last):
  File ".../project_3/osc.py", line 71, in <module>
    solution = dfx.diffeqsolve(
               ^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/equinox/_jit.py", line 206, in __call__
    return self._call(False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/equinox/_module.py", line 1053, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/equinox/_jit.py", line 200, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 327, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 175, in _python_pjit_helper
    attrs_tracked) = _infer_params(jit_info, args, kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 627, in _infer_params
    jaxpr, consts, out_shardings_flat, out_layouts_flat, attrs_tracked = _pjit_jaxpr(
                                                                         ^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1275, in _pjit_jaxpr
    jaxpr, final_consts, out_type, attrs_tracked = _create_pjit_jaxpr(
                                                   ^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 350, in memoized_fun
    ans = call(fun, *args)
          ^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1189, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
                                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2347, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2370, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 192, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/equinox/_jit.py", line 49, in fun_wrapped
    out = fun(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/diffrax/_integrate.py", line 781, in diffeqsolve
    raise ValueError(
ValueError: `terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with structure <class 'diffrax._term.AbstractTerm'>

Note the breakpoint I set and the type it reports:

% python osc.py
> /.../osc.py(69)<module>()
-> saveAt = dfx.SaveAt(ts=jnp.linspace(t_start, t_end, num_points))
(Pdb) type(term)
<class 'diffrax._term.ODETerm'>

The type is what is expected. So what exactly am I doing wrong?

Code:

import jax
import jax.numpy as jnp
import diffrax as dfx
import equinox as eqx
import matplotlib.pyplot as plt
from jax import random

# Define the neural network for the external force using Equinox
class ForceMLP(eqx.Module):
    input: eqx.nn.Linear
    dense1: eqx.nn.Linear
    dense2: eqx.nn.Linear
    output: eqx.nn.Linear

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.input = eqx.nn.Linear(1, 256, key=key1)
        self.dense1 = eqx.nn.Linear(256, 256, key=key2)
        self.dense2 = eqx.nn.Linear(256, 256, key=key3)
        self.output = eqx.nn.Linear(256, 1, key=key4)

    def __call__(self, t):
        x = self.input(t)
        x = jax.nn.tanh(x)
        x = self.dense1(x)
        x = jax.nn.relu(x)
        x = self.dense2(x)
        x = jax.nn.relu(x)
        F = self.output(x)
        return F

# Initialize the neural network
key = random.PRNGKey(0)
force_mlp = ForceMLP(key)

def get_force(t):
    return force_mlp(t)

# Define the equations for the ODE
def equations(t, y, args):
    position, velocity = y
    force = get_force(t)

    # Damped harmonic oscillator equations
    damping = 0.1
    spring_constant = 1.0

    dposition_dt = velocity
    dvelocity_dt = -damping * velocity - spring_constant * position + force

    return jnp.array([dposition_dt, dvelocity_dt])

# Initial conditions and time span
y0 = jnp.array([1.0, 0.0])  # Initial position and velocity
t_start = 0.0
t_end = 10.0
num_points = 100

# ODE solver using diffrax
solver = dfx.Tsit5()  # Tsitouras 5th order method
stepsize_controller = dfx.PIDController(rtol=1e-6, atol=1e-6)
term = dfx.ODETerm(equations)
breakpoint()
saveAt = dfx.SaveAt(ts=jnp.linspace(t_start, t_end, num_points))

# Solve the ODE
solution = dfx.diffeqsolve(
    term,
    solver,
    t0=t_start,
    t1=t_end,
    dt0=0.1,
    y0=y0,
    saveat=saveAt,
    stepsize_controller=stepsize_controller,
)

# Plot the solution
ts = solution.ts
ys = solution.ys

plt.plot(ts, ys[:, 0], label="Position")
plt.plot(ts, ys[:, 1], label="Velocity")
plt.xlabel("Time")
plt.ylabel("Values")
plt.legend()
plt.title("Damped Harmonic Oscillator with Neural Network Force")
plt.show()

lockwo commented 5 months ago

Ahh the classic

ValueError: `terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with structure <class 'diffrax._term.AbstractTerm'>

I do think there could be a more informative error message here because, in my experience, 9 times out of 10 this is caused by a shape mismatch in the term's input or output. A lot of the time, this can be revealed by just manually inspecting the drift function. In this case, if we just evaluate `print(equations(t_start, y0, None))`, we see that it fails with a shape error: `ValueError: matmul input operand 1 must have ndim at least 1, but it has ndim 0`. Basically, `t` is a scalar, but the matmul expects something with at least 1 dimension. The fix is just to add a dimension to `t` (and to squeeze the network output back to a scalar, so the returned array keeps the same shape as `y0`).
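As a minimal sketch of that shape problem in isolation: the `weight` matrix and `first_layer` function below are hypothetical stand-ins for the first `eqx.nn.Linear(1, 256)` layer, not the Equinox implementation. A matrix-vector product rejects a 0-dimensional input but works once the input has shape `(1,)`.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the first layer: a (256, 1) weight applied with a
# matrix-vector product, so it expects an input of shape (1,).
weight = jax.random.normal(jax.random.PRNGKey(0), (256, 1))


def first_layer(t):
    return weight @ t


# The time value handed to the vector field is a scalar, i.e. shape ().
t_scalar = jnp.asarray(0.0)

try:
    first_layer(t_scalar)
    hidden_error = None
except Exception as e:  # exact exception type and message may vary by JAX version
    hidden_error = e
print("scalar input failed with:", repr(hidden_error))

# Adding the missing dimension makes the product well defined.
t_vector = jnp.reshape(t_scalar, (1,))
out = first_layer(t_vector)
print("vector input gives shape:", out.shape)
```

The Equinox documentation also describes a special `"scalar"` value for `in_features` and `out_features` of `eqx.nn.Linear`, which may remove the need for the manual reshape and squeeze; check it against the Equinox version you have installed.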

Here is the full fix:

import jax
import jax.numpy as jnp
import diffrax as dfx
import equinox as eqx
import matplotlib.pyplot as plt
from jax import random

# Define the neural network for the external force using Equinox
class ForceMLP(eqx.Module):
    input: eqx.nn.Linear
    dense1: eqx.nn.Linear
    dense2: eqx.nn.Linear
    output: eqx.nn.Linear

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.input = eqx.nn.Linear(1, 256, key=key1)
        self.dense1 = eqx.nn.Linear(256, 256, key=key2)
        self.dense2 = eqx.nn.Linear(256, 256, key=key3)
        self.output = eqx.nn.Linear(256, 1, key=key4)

    def __call__(self, t):
        x = self.input(t)
        x = jax.nn.tanh(x)
        x = self.dense1(x)
        x = jax.nn.relu(x)
        x = self.dense2(x)
        x = jax.nn.relu(x)
        F = self.output(x)
        return F

# Initialize the neural network
key = random.PRNGKey(0)
force_mlp = ForceMLP(key)

def get_force(t):
    return force_mlp(t)

# Define the equations for the ODE
def equations(t, y, args):
    position, velocity = y
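    t = jnp.array([t])  # the first Linear layer expects shape (1,), not a scalar
    force = get_force(t).squeeze()  # back to a scalar so the output matches y0's shape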

    # Damped harmonic oscillator equations
    damping = 0.1
    spring_constant = 1.0

    dposition_dt = velocity
    dvelocity_dt = -damping * velocity - spring_constant * position + force

    return jnp.array([dposition_dt, dvelocity_dt])

# Initial conditions and time span
y0 = jnp.array([1.0, 0.0])  # Initial position and velocity
t_start = 0.0
t_end = 10.0
num_points = 100

# Sanity check: call the vector field directly so any shape error surfaces here
print(equations(t_start, y0, None))

# ODE solver using diffrax
solver = dfx.Tsit5()  # Tsitouras 5th order method
stepsize_controller = dfx.PIDController(rtol=1e-6, atol=1e-6)
term = dfx.ODETerm(equations)
saveAt = dfx.SaveAt(ts=jnp.linspace(t_start, t_end, num_points))

# Solve the ODE
solution = dfx.diffeqsolve(
    term,
    solver,
    t0=t_start,
    t1=t_end,
    dt0=0.1,
    y0=y0,
    saveat=saveAt,
    stepsize_controller=stepsize_controller,
)

# Plot the solution
ts = solution.ts
ys = solution.ys

plt.plot(ts, ys[:, 0], label="Position")
plt.plot(ts, ys[:, 1], label="Velocity")
plt.xlabel("Time")
plt.ylabel("Values")
plt.legend()
plt.title("Damped Harmonic Oscillator with Neural Network Force")
plt.show()

patrick-kidger commented 5 months ago

Thank you @lockwo for the help!

By the way, I'd be very happy to take a pull request improving this error message, describing whatever you think most needs adding!

(It's new to me that this is apparently 'classic' -- probably that's because I'm used to navigating this library in a rather different way... :D )

pascal-mueller commented 5 months ago

Thanks for the help. I am very new to JAX and Diffrax, so I had some trouble debugging this: I usually step into my programs with a debugger, but I'm not used to the functional and JIT nature of JAX yet.

I realized pretty quickly that there was a "hidden error" behind this error message, but I couldn't figure out what it was or how to access it. That's mainly due to inexperience.

Thanks a lot
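One way to get at such a hidden error without stepping through JIT-compiled code is to check the vector field on its own before handing it to the solver. The sketch below uses `jax.eval_shape`, which traces a function abstractly and reports output shapes and dtypes without running the computation. The `vector_field` function and its constant force are illustrative assumptions, not the code from this issue. Any exception raised inside the function propagates directly, and the result can be compared against `y0`.

```python
import jax
import jax.numpy as jnp


def vector_field(t, y, args):
    # Damped oscillator with a constant force (an illustrative placeholder).
    position, velocity = y
    force = 0.5
    return jnp.array([velocity, -0.1 * velocity - position + force])


y0 = jnp.array([1.0, 0.0])

# Abstract evaluation: only shapes and dtypes are computed.
out = jax.eval_shape(vector_field, 0.0, y0, None)
print(out.shape, out.dtype)

# For an ODE, the derivative should have the same shape and dtype as the state.
matches = (out.shape == y0.shape) and (out.dtype == y0.dtype)
print("output matches y0:", matches)
```

If the call raises, the traceback points at the offending line in the vector field rather than at `diffeqsolve`.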

lockwo commented 5 months ago

> Thank you @lockwo for the help!
>
> By the way, I'd be very happy to take a pull request improving this error message, describing whatever you think most needs adding!
>
> (It's new to me that this is apparently 'classic' -- probably that's because I'm used to navigating this library in a rather different way... :D )

Maybe classic is a strong word haha, but I've been doing a lot of 1D systems this past week, so I saw it a lot whenever I messed up how I squeezed or unsqueezed things, and I got used to seeing the message. I will see about adding a clearer error statement about shapes where possible.
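For 1D systems, the squeeze/unsqueeze bookkeeping mentioned above can be sketched as follows. The `decay` vector field is a hypothetical example. The point is only that a state of shape `()` and a state of shape `(1,)` are different, and the vector field should return whichever shape it was given.

```python
import jax.numpy as jnp

y_scalar = jnp.asarray(2.0)          # shape ()
y_vector = jnp.atleast_1d(y_scalar)  # shape (1,), i.e. "unsqueezed"
y_back = jnp.squeeze(y_vector)       # shape () again, i.e. "squeezed"


def decay(t, y, args):
    # Elementwise, so the output has exactly the shape of the input state.
    return -0.5 * y


print(decay(0.0, y_scalar, None).shape)  # ()
print(decay(0.0, y_vector, None).shape)  # (1,)
```

Picking one convention for the state and keeping it throughout the vector field and `y0` avoids this class of mismatch.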