suargi opened 1 week ago
You have grad_loss = eqx.filter_value_and_grad(loss, has_aux=True), but you don't actually return any auxiliary variables from loss. Setting that to False yields:
Step: 0, Loss: 0.17582178115844727, Computation time: 13.726179122924805
Step: 100, Loss: 0.012010098434984684, Computation time: 0.04791116714477539
Step: 200, Loss: 0.01128536369651556, Computation time: 0.06833648681640625
Step: 300, Loss: 0.006681683007627726, Computation time: 0.03933405876159668
Step: 400, Loss: 0.008453472517430782, Computation time: 0.034162044525146484
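(For reference, has_aux=True tells eqx.filter_value_and_grad to expect the loss function to return a (loss, aux) pair rather than a bare scalar. A minimal sketch of the two signatures; the toy model, data and aux value here are purely illustrative:)

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

model = eqx.nn.MLP(in_size=2, out_size=2, width_size=8, depth=1, key=jr.PRNGKey(0))
x = jnp.ones((16, 2))
y = jnp.zeros((16, 2))

def loss_no_aux(model, x, y):
    return jnp.mean((jax.vmap(model)(x) - y) ** 2)

def loss_with_aux(model, x, y):
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2), pred  # must return a (loss, aux) pair

# has_aux=False (default): the value is just the scalar loss.
loss0, grads0 = eqx.filter_value_and_grad(loss_no_aux)(model, x, y)
# has_aux=True: the value is the (loss, aux) pair.
(loss1, pred), grads1 = eqx.filter_value_and_grad(loss_with_aux, has_aux=True)(model, x, y)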
Thank you, that solved the issue.
I have tried different hyperparameter combinations (number of epochs, learning rate, number of layers, etc.), but I cannot get results as accurate as with the Neural ODE. I am wondering if there is some problem with my code. Would it be possible for you to take a look and verify that my implementation is correct? Thank you.
import time
import diffrax
import equinox as eqx # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import matplotlib
import matplotlib.pyplot as plt
import optax
class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.tanh,
            final_activation=jnn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y)
class NeuralCDE(eqx.Module):
    initial: eqx.nn.MLP
    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        ikey, fkey, lkey = jr.split(key, 3)
        self.initial = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            key=ikey,
        )
        self.func = Func(data_size, width_size, depth, key=fkey)

    def __call__(self, ts, coeffs, evolving_out=False):
        # Each sample of data consists of some timestamps `ts`, and some `coeffs`
        # parameterising a control path. These are used to produce a continuous-time
        # input path `control`.
        control = diffrax.CubicInterpolation(ts, coeffs)
        term = diffrax.ControlTerm(self.func, control).to_ode()
        solver = diffrax.Tsit5()
        dt0 = ts[1] - ts[0]
        y0 = self.initial(control.evaluate(ts[0]))
        solution = diffrax.diffeqsolve(
            term,
            solver,
            ts[0],
            ts[-1],
            dt0,
            y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts),
        )
        return solution.ys
# ============================================================================
def _get_data(ts, *, key):
    y0 = jr.uniform(key, (2,), minval=-0.6, maxval=1)

    def f(t, y, args):
        x = y / (1 + y)
        return jnp.stack([x[1], -x[0]], axis=-1)

    solver = diffrax.Tsit5()
    dt0 = 0.1
    saveat = diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f), solver, ts[0], ts[-1], dt0, y0, saveat=saveat
    )
    ys = sol.ys
    return ys
def get_data(dataset_size, *, key):
    length = 100
    ts = jnp.linspace(0, 10, length)
    key = jr.split(key, dataset_size)
    ys = jax.vmap(lambda key: _get_data(ts, key=key))(key)
    ts_broadcasted = jnp.broadcast_to(ts, (dataset_size, length))
    ys = jnp.concatenate([ts_broadcasted[:, :, None], ys], axis=-1)  # time is a channel
    coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts_broadcasted, ys)
    return ts_broadcasted, ys, coeffs
# ============================================================================
def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size
def main(
    dataset_size=256,
    batch_size=64,
    lr_strategy=(3e-3, 3e-3),
    steps_strategy=(500, 500),
    length_strategy=(1, 1),
    width_size=64,
    depth=2,
    seed=5678,
    plot=True,
    print_every=100,
):
    key = jr.PRNGKey(seed)
    data_key, model_key, loader_key = jr.split(key, 3)

    ts, ys, coeffs = get_data(dataset_size, key=data_key)
    _, length_size, data_size = ys.shape

    model = NeuralCDE(data_size, width_size, depth, key=model_key)

    # Training loop like normal.
    #
    # (In the neural ODE example, `length_strategy` is used to train on only the first
    # 10% of each time series during the first phase, as a standard trick to avoid
    # getting caught in a local minimum. Here `length_strategy=(1, 1)`, so the full
    # time series is used in both phases.)

    @eqx.filter_jit
    def loss(model, ti, yi, coeff_i):
        y_pred = jax.vmap(model, in_axes=(None, 0))(ti[0, :], coeff_i)
        # MSE without the time column
        return jnp.mean((yi[:, :, 1:] - y_pred[:, :, 1:]) ** 2)

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=False)

    @eqx.filter_jit
    def make_step(data_i, model, opt_state):
        ti, yi, *coeff_i = data_i
        loss, grads = grad_loss(model, ti, yi, coeff_i)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
        optim = optax.adam(lr)
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        _ts = ts[:, : int(length_size * length)]
        _ys = ys[:, : int(length_size * length)]
        _coeffs = tuple(arr[:, : int(length_size * length) - 1] for arr in coeffs)
        for step, data_i in zip(
            range(steps), dataloader((_ts, _ys) + _coeffs, batch_size, key=loader_key)
        ):
            start = time.time()
            loss, model, opt_state = make_step(data_i, model, opt_state)
            end = time.time()
            if (step % print_every) == 0 or step == steps - 1:
                print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")

    if plot:
        ts = ts[0, :]
        plt.plot(ts, ys[0, :, 1], c="dodgerblue", label="Real")
        plt.plot(ts, ys[0, :, 2], c="dodgerblue")
        sample_coeffs = tuple(c[-1] for c in coeffs)
        pred = model(ts, sample_coeffs, evolving_out=True)
        plt.plot(ts, pred[:, 1], c="crimson", label="Model")
        plt.plot(ts, pred[:, 2], c="crimson")
        plt.legend()
        plt.tight_layout()
        plt.savefig("neural_ode.png")
        plt.show()

    return ts, ys, coeffs, model


ts, ys, coeffs, model = main()
I'm probably not familiar enough with Neural CDEs to be able to diagnose issues without substantial investigation. I would recommend checking piece by piece to make sure each of the subroutines is operating as expected, e.g. by comparing to specific known solutions on small problems.
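For instance, a minimal sketch of that kind of check (not part of the code above; the test problem and tolerances are arbitrary choices): solve a small ODE with a known closed-form solution and compare against it, and confirm the cubic interpolation reproduces the data at its knots.

import diffrax
import jax.numpy as jnp

# Known solution: dy/dt = -y, y(0) = 1  =>  y(t) = exp(-t).
ts = jnp.linspace(0, 1, 11)
sol = diffrax.diffeqsolve(
    diffrax.ODETerm(lambda t, y, args: -y),
    diffrax.Tsit5(),
    t0=ts[0],
    t1=ts[-1],
    dt0=0.1,
    y0=jnp.array(1.0),
    saveat=diffrax.SaveAt(ts=ts),
    stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
)
assert jnp.allclose(sol.ys, jnp.exp(-ts), atol=1e-4)

# The Hermite cubic interpolation should pass through the data at the knots.
ys = jnp.stack([ts, jnp.sin(ts)], axis=-1)
coeffs = diffrax.backward_hermite_coefficients(ts, ys)
interp = diffrax.CubicInterpolation(ts, coeffs)
assert jnp.allclose(interp.evaluate(ts[3]), ys[3], atol=1e-5)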
Description
I would like to create a neural CDE for regression. For that, I have taken the example of a neural CDE for classification and adapted it using the content from the neural ODE for regression example.
I am encountering some issues which I do not know how to solve. I would appreciate it if someone could point me in the right direction. Thank you!
Code
Error
The error originates at line
The error message is too long to write down here. To replicate the error, please run the code above. Note: my intention is to compute the MSE between the predicted values and the true values. The variables y and y_pred contain the timestamps in the first column. Therefore, for the MSE I only use the last two columns.
Specifications
jax 0.4.35
jaxlib 0.4.35
jaxtyping 0.2.34
diffrax 0.6.0
equinox 0.11.8
numpy 2.1.2
optax 0.2.3