pascal-mueller opened 5 months ago
Ahh the classic
ValueError: `terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with structure <class 'diffrax._term.AbstractTerm'>
I do think there could be a more informative error message here, because in my experience 9 times out of 10 this is caused by a shape mismatch in the term's input/output. A lot of the time, this can be revealed by just manually evaluating the drift function. In this case, if we just evaluate `print(equations(t_start, y0, None))`, we see that it fails with a shape error: `ValueError: matmul input operand 1 must have ndim at least 1, but it has ndim 0`. Basically, `t` is a scalar, but the first `eqx.nn.Linear` layer (a matmul) expects an input with at least one dimension. The fix is just to add a dimension to `t`.
Here is the full fix:
```python
import jax
import jax.numpy as jnp
import diffrax as dfx
import equinox as eqx
import matplotlib.pyplot as plt
from jax import random

# Define the neural network for the external force using Equinox
class ForceMLP(eqx.Module):
    input: eqx.nn.Linear
    dense1: eqx.nn.Linear
    dense2: eqx.nn.Linear
    output: eqx.nn.Linear

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.input = eqx.nn.Linear(1, 256, key=key1)  # expects input of shape (1,)
        self.dense1 = eqx.nn.Linear(256, 256, key=key2)
        self.dense2 = eqx.nn.Linear(256, 256, key=key3)
        self.output = eqx.nn.Linear(256, 1, key=key4)  # returns shape (1,)

    def __call__(self, t):
        x = self.input(t)
        x = jax.nn.tanh(x)
        x = self.dense1(x)
        x = jax.nn.relu(x)
        x = self.dense2(x)
        x = jax.nn.relu(x)
        F = self.output(x)
        return F

# Initialize the neural network
key = random.PRNGKey(0)
force_mlp = ForceMLP(key)

def get_force(t):
    return force_mlp(t)

# Define the equations for the ODE
def equations(t, y, args):
    position, velocity = y
    # The fix: t arrives as a scalar (shape ()), but the first Linear layer
    # needs shape (1,), so add a dimension before calling the network.
    t = jnp.array([t])
    force = get_force(t).squeeze()  # shape (1,) -> scalar

    # Damped harmonic oscillator equations
    damping = 0.1
    spring_constant = 1.0
    dposition_dt = velocity
    dvelocity_dt = -damping * velocity - spring_constant * position + force
    return jnp.array([dposition_dt, dvelocity_dt])

# Initial conditions and time span
y0 = jnp.array([1.0, 0.0])  # Initial position and velocity
t_start = 0.0
t_end = 10.0
num_points = 100

# Sanity check: call the vector field directly, outside the solver
print(equations(t_start, y0, None))

# ODE solver using diffrax
solver = dfx.Tsit5()  # Tsitouras 5th order method
stepsize_controller = dfx.PIDController(rtol=1e-6, atol=1e-6)
term = dfx.ODETerm(equations)
saveat = dfx.SaveAt(ts=jnp.linspace(t_start, t_end, num_points))

# Solve the ODE
solution = dfx.diffeqsolve(
    term,
    solver,
    t0=t_start,
    t1=t_end,
    dt0=0.1,
    y0=y0,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)

# Plot the solution
ts = solution.ts
ys = solution.ys
plt.plot(ts, ys[:, 0], label="Position")
plt.plot(ts, ys[:, 1], label="Velocity")
plt.xlabel("Time")
plt.ylabel("Values")
plt.legend()
plt.title("Damped Harmonic Oscillator with Neural Network Force")
plt.show()
```
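A stripped-down version of the same mismatch, using only JAX: `weight` is a hypothetical stand-in with the shape of the first `eqx.nn.Linear(1, 256)` layer's weight, `(256, 1)`. Multiplying it by a 0-d `t` raises the matmul error, while `jnp.array([t])`, `jnp.atleast_1d(t)`, or `t[None]` all give `t` the shape `(1,)` the layer needs. This is a sketch of the shape logic only, not of Equinox internals.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the weight of eqx.nn.Linear(1, 256):
# shape (out_features, in_features). Only the shapes matter here.
weight = jnp.ones((256, 1))

t = jnp.asarray(0.0)  # a scalar time, shape ()

try:
    weight @ t  # 0-d operand: the same kind of failure as inside the Linear layer
    scalar_failed = False
except (ValueError, TypeError) as err:
    scalar_failed = True
    print("scalar input:", err)

# Three ways to add the missing dimension, each giving shape (1,).
candidates = {
    "jnp.array([t])": jnp.array([t]),
    "jnp.atleast_1d(t)": jnp.atleast_1d(t),
    "t[None]": t[None],
}
for name, t_vec in candidates.items():
    out = weight @ t_vec
    print(name, t_vec.shape, "->", out.shape)

# The same fix works under jax.jit, where t is a tracer rather than a value.
@jax.jit
def first_layer(t):
    return jnp.tanh(weight @ jnp.atleast_1d(t))

print(first_layer(0.5).shape)
```

Note that `t[None]` only works when `t` is already an array; a plain Python float such as `t_start = 0.0` needs `jnp.array([t])` or `jnp.atleast_1d(t)`. Recent Equinox versions also appear to accept `in_features="scalar"` in `eqx.nn.Linear` for 0-d inputs; check the Equinox documentation for your version before relying on it.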
Thank you @lockwo for the help!
By the way, I'd be very happy to take a pull request improving this error message, describing whatever you think most needs adding!
(It's new to me that this is apparently 'classic' -- probably that's because I'm used to navigating this library in a rather different way... :D )
Thanks for the help. I am very new to JAX and Diffrax, so I had some trouble debugging this: I usually step through my programs with a debugger, and I'm not used to the functional, JIT-compiled nature of JAX yet.
I realized pretty quickly that there was a "hidden error" behind this error message, but I just couldn't figure out what it was or how to access it. That's mainly due to inexperience.
Thanks a lot
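One way to get at that hidden error, sketched with plain JAX under stated assumptions: exercise the vector field on its own before handing it to `diffeqsolve`. The force network is replaced by a single hypothetical weight matrix `W` so the sketch runs without Equinox or Diffrax. An eager call lets exceptions (and a normal Python debugger) reach you directly, `jax.eval_shape` reports output shapes without doing any numerical work, and `jax.debug.print` shows runtime values even under `jax.jit`, where an ordinary `print` only shows abstract tracers.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the force network: one (1, 1) weight matrix,
# so it maps shape (1,) -> (1,) like ForceMLP in the thread.
W = jnp.full((1, 1), 0.5)

def equations(t, y, args):
    position, velocity = y
    t = jnp.atleast_1d(t)  # delete this line to reproduce the scalar-input error
    force = (W @ t).squeeze()
    # Unlike print, jax.debug.print shows runtime values even under jax.jit.
    jax.debug.print("t={t}, force={f}", t=t, f=force)
    return jnp.stack([velocity, -0.1 * velocity - position + force])

y0 = jnp.array([1.0, 0.0])

# 1) Eager call: exceptions surface directly and a Python debugger can step in.
dy = equations(0.0, y0, None)

# 2) Shapes and dtypes only, no numerical work.
out_struct = jax.eval_shape(equations, 0.0, y0, None)
print(out_struct)
assert out_struct.shape == y0.shape, "vector field output must match y0's shape"

# 3) Traced and compiled, closer to how a solver calls the function.
dy_jit = jax.jit(equations)(0.0, y0, None)
```

Step 1 is the same check used earlier in this thread (`print(equations(t_start, y0, None))`); steps 2 and 3 approximate what happens once the function is traced.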
Maybe "classic" is a strong word haha, but I've been working with a lot of 1D systems this past week, so I saw it often whenever I messed up how I squeezed or unsqueezed arrays, and I got used to the message. I'll see about adding a clearer error message about shapes where possible.
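Since most of these errors come down to a squeeze or unsqueeze in the wrong place, here is one possible shape check, offered as a sketch. `check_vector_field` is a hypothetical helper, not part of Diffrax: it traces the vector field once with `jax.eval_shape` and compares each output leaf's shape with the matching leaf of `y0`. Requiring an exact match is an assumption of this sketch and may be stricter than what a given solver would accept via broadcasting.

```python
import jax
import jax.numpy as jnp

def check_vector_field(f, t0, y0, args=None):
    """Hypothetical pre-flight check (not part of Diffrax): trace f once and
    compare its output pytree structure and leaf shapes against y0's."""
    try:
        out = jax.eval_shape(f, t0, y0, args)
    except Exception as err:  # surface whatever went wrong while tracing f
        raise ValueError(f"vector field failed on (t0, y0): {err}") from err
    if jax.tree_util.tree_structure(out) != jax.tree_util.tree_structure(y0):
        raise ValueError("vector field output structure does not match y0")
    for o, y in zip(jax.tree_util.tree_leaves(out), jax.tree_util.tree_leaves(y0)):
        if o.shape != y.shape:
            raise ValueError(f"output shape {o.shape} does not match y0 shape {y.shape}")
    return out

# A 1D system: y0 has shape (1,), so the vector field should return shape (1,).
y0 = jnp.array([1.0])

def decay_ok(t, y, args):
    return -y  # shape (1,)

def decay_squeezed(t, y, args):
    return (-y).squeeze()  # shape (): one squeeze too many

check_vector_field(decay_ok, 0.0, y0)
try:
    check_vector_field(decay_squeezed, 0.0, y0)
except ValueError as err:
    print(err)
```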
I have this code solving an ODE. If I set the force inside `equations` to a numerical value, everything is fine, but if I try to replace it with a neural network, I get:
Note the breakpoint I set and the type.
The type is what is expected. So what exactly am I doing wrong?
Code: