FabiJa opened this issue 8 months ago
Considering the final "forcing term" example: try replacing the `jax.grad` there with a `jax.vmap`.
Then you should be able to pass in a batch of points, so that each batch element gets a different forcing term.
Does that help?
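Here is a minimal sketch of that suggestion. It is written in plain JAX so that it runs on its own: a hand-written fixed-step RK4 loop stands in for `diffrax.diffeqsolve`, and `jnp.interp` stands in for `diffrax.LinearInterpolation`. The vector field `-y + u(t)`, the initial value `1.0`, the step count and the example points are all assumptions made for illustration. The point is only that wrapping `solve` in `jax.vmap` (where the docs example uses `jax.grad`) gives one solve per row of forcing points.

```python
import jax
import jax.numpy as jnp

T0, T1, N_STEPS = 0.0, 10.0, 1000  # assumed time span and number of RK4 steps


def vector_field(t, y, forcing_ts, forcing_ys):
    # Linear interpolation of the forcing samples at time t
    # (stand-in for an interpolation object's evaluate(t)).
    u = jnp.interp(t, forcing_ts, forcing_ys)
    return -y + u


def solve(points):
    # The forcing samples are assumed to be spread evenly over [T0, T1].
    forcing_ts = jnp.linspace(T0, T1, points.shape[0])
    dt = (T1 - T0) / N_STEPS

    def rhs(t, y):
        return vector_field(t, y, forcing_ts, points)

    def rk4_step(y, i):
        t = T0 + i * dt
        k1 = rhs(t, y)
        k2 = rhs(t + dt / 2, y + dt / 2 * k1)
        k3 = rhs(t + dt / 2, y + dt / 2 * k2)
        k4 = rhs(t + dt, y + dt * k3)
        return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4), None

    y1, _ = jax.lax.scan(rk4_step, jnp.array(1.0), jnp.arange(N_STEPS))
    return y1  # state at T1


# jax.vmap instead of jax.grad: one solve per row of `batch_points`,
# i.e. a different forcing term for every batch element.
batched_solve = jax.jit(jax.vmap(solve))

batch_points = jnp.array([
    [3.0, 0.5, -0.8, 1.8],
    [0.0, 0.0, 0.0, 0.0],  # zero forcing: y' = -y
])
y1s = batched_solve(batch_points)
```

With the second row set to zero forcing, that batch element reduces to `y' = -y`, which gives an easy sanity check against `exp(-10)`.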
Thank you. I will test it and give feedback here.
If I understand it correctly, the interpolation coefficients are computed "on the fly", i.e. during training. If so, it would be nice to have this separated from the training process, also for modularity, in case one wants to change the interpolation scheme.
By the way, great work! Astonishing pace of new JAX packages from you :-O
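Here is a minimal sketch of that separation, assuming a simple piecewise-linear scheme. In diffrax the analogous split would be calling `backward_hermite_coefficients` once and building a `CubicInterpolation` from the result. The coefficients are computed once per data set before training, and only the cheap evaluation runs inside the model. The helpers `linear_coeffs` and `evaluate` and the example data are hypothetical.

```python
import jax
import jax.numpy as jnp


def linear_coeffs(ts, us):
    # Hypothetical piecewise-linear scheme: on interval i,
    # u(t) = a[i] + b[i] * (t - ts[i]).
    a = us[:-1]
    b = jnp.diff(us) / jnp.diff(ts)
    return a, b


def evaluate(t, ts, coeffs):
    a, b = coeffs
    # Index of the interval containing t; clipping makes values outside
    # [ts[0], ts[-1]] extrapolate with the first/last segment.
    i = jnp.clip(jnp.searchsorted(ts, t, side="right") - 1, 0, ts.shape[0] - 2)
    return a[i] + b[i] * (t - ts[i])


# Done once, before training: one set of coefficients per data set.
ts_inputs = jnp.tile(jnp.linspace(0.0, 1.0, 5), (3, 1))  # (batch, n_points)
us = jnp.stack([ts_inputs[0] ** 2, jnp.sin(ts_inputs[1]), jnp.ones(5)])
coeff_array = jax.vmap(linear_coeffs)(ts_inputs, us)  # tuple of (batch, n_points - 1)

# Inside the model, only the evaluation runs at every training step.
u_mid = jax.vmap(evaluate, in_axes=(None, 0, 0))(0.5, ts_inputs, coeff_array)
```

Swapping the interpolation scheme then only means swapping the pair of coefficient and evaluation functions, while the training loop just carries `coeff_array` around as data.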
OK, short update: I opted for an alternative approach based on the NODE example:
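```python
import time

import equinox as eqx
import jax
import jax.numpy as jnp


# Only `model` (the first argument) is differentiated, so no gradient is taken
# w.r.t. the input time grids or the interpolation coefficients.
@eqx.filter_value_and_grad
def grad_loss(model, ti, yi, ts_inputs_i, coeff_array_i):
    # Map over the batch axis: each element has its own time grid, initial
    # state, input time grid and interpolation coefficients.
    y_pred = jax.vmap(model, in_axes=(0, 0, 0, 0))(ti, yi[:, 0], ts_inputs_i, coeff_array_i)
    return jnp.mean((yi - y_pred) ** 2)


@eqx.filter_jit
def make_step(ti, yi, ts_inputs_i, coeff_array_i, model, opt_state):
    loss, grads = grad_loss(model, ti, yi, ts_inputs_i, coeff_array_i)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

...

data_for_batching = (_ys, _ts, ts_inputs_array, coeff_array)
for step, batch_data in zip(
    range(steps), dataloader(data_for_batching, batch_size, key=loader_key)
):
    yi, ti, ts_inputs_i, coeff_array_i = batch_data
    start = time.time()
    # Pass the per-batch inputs explicitly instead of reading them as globals.
    loss, model_train, opt_state = make_step(
        ti, yi, ts_inputs_i, coeff_array_i, model_train, opt_state
    )
    end = time.time()
    if (step % print_every) == 0 or step == steps - 1:
        print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")
```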
I don't know if this makes sense, but at least it seems to work (so far). `coeff_array_i` holds the interpolation coefficients computed by, e.g., `backward_hermite_coefficients`.
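One thing to watch with this pattern: a per-batch array such as `ts_inputs_i` or `coeff_array_i` should be passed into the jitted step as an argument, not read from the enclosing scope. A jitted function reads enclosing-scope values while it is being traced, and later calls with the same input shapes typically reuse the compiled version, so it would keep using the first batch's values. The toy functions below (`reads_global`, `takes_argument`, `scale`) are hypothetical and only illustrate this behaviour of `jax.jit`. The same applies to `eqx.filter_jit`, which wraps it.

```python
import jax
import jax.numpy as jnp

scale = 1.0  # stands in for a per-batch value such as coeff_array_i


@jax.jit
def reads_global(x):
    # `scale` is looked up while tracing, so the value it had on the first
    # call is baked into the compiled function.
    return x * scale


@jax.jit
def takes_argument(x, scale):
    # Passed as an argument, `scale` is a traced input and may change per call.
    return x * scale


x = jnp.array(2.0)
first = reads_global(x)           # traced and compiled with scale == 1.0
scale = 10.0                      # "next batch"
stale = reads_global(x)           # same input shape/dtype: cached version, still uses 1.0
fresh = takes_argument(x, scale)  # uses 10.0
```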
This looks reasonable to me!
Hi,
sorry for my slightly uninformed question, but I am new to the JAX ecosystem.
I have several data sets of measurements of one dynamic system, each recorded with a different excitation $u(t)$, and I want to learn the system's dynamics. So the excitation changes, but the system (ODE -> NODE) is the same. I want to use equinox + diffrax to train a neural ODE with batching, where the ODE has an external input $u$, i.e. it is described by $\dot{x} = f(x, u(t))$. The time dependency of $u(t)$ is not known explicitly (it has to be interpolated from data) and varies per batch element.

Looking in the docs, I found the forcing-term example and the batch training of NODEs. My problem is how to combine both. My first hack was to map the $u(t)$ of every batch element to non-overlapping time periods, to get a unique mapping from time to the correct input time series. Then I am able to use `vmap` directly. Are there any better options to handle this? Note that the gradient should not be calculated with respect to the parameters of the interpolation object representing $u(t)$.

Thanks. If there are any questions, let me know.
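Here is a minimal sketch of an alternative to the non-overlapping-time-periods trick, assuming a scalar state and a scalar input. Each batch element gets its own input samples `(ts_u, us)`, and the solve is vmapped over them together with `x0`, while the network parameters are shared via `in_axes=None`. To keep it runnable on its own, it uses plain JAX: a fixed-step RK4 loop stands in for `diffrax.diffeqsolve`, `jnp.interp` stands in for an interpolation object, and a dict-of-arrays MLP stands in for an equinox model. These choices, the helper names and the placeholder data are all assumptions. Because `jax.value_and_grad` only differentiates w.r.t. `argnums=0`, no gradient is computed for the input data.

```python
import jax
import jax.numpy as jnp

T1, N_STEPS = 1.0, 100  # assumed time span [0, T1] and number of RK4 steps


def init_params(key, hidden=16):
    # Tiny MLP for f(x, u): input is [x, u] (scalar state, scalar input).
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (2, hidden)) * 0.5,
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden,)) * 0.5,
        "b2": jnp.zeros(()),
    }


def f(params, x, u):
    h = jnp.tanh(jnp.array([x, u]) @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


def solve(params, x0, ts_u, us):
    # ts_u, us: this batch element's own input measurements; u(t) is
    # linearly interpolated from them.
    dt = T1 / N_STEPS

    def rhs(t, x):
        return f(params, x, jnp.interp(t, ts_u, us))

    def step(x, i):
        t = i * dt
        k1 = rhs(t, x)
        k2 = rhs(t + dt / 2, x + dt / 2 * k1)
        k3 = rhs(t + dt / 2, x + dt / 2 * k2)
        k4 = rhs(t + dt, x + dt * k3)
        x_next = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return x_next, x_next

    _, xs = jax.lax.scan(step, x0, jnp.arange(N_STEPS))
    return xs  # state at t = dt, 2*dt, ..., T1


def loss(params, x0s, ts_us, uss, x_targets):
    # vmap over the batch; params are shared across elements (in_axes=None).
    preds = jax.vmap(solve, in_axes=(None, 0, 0, 0))(params, x0s, ts_us, uss)
    return jnp.mean((preds - x_targets) ** 2)


# Gradient w.r.t. params only (argnums=0): nothing flows into the input data.
grad_fn = jax.jit(jax.value_and_grad(loss, argnums=0))

# Placeholder data: three batch elements with different excitations.
params = init_params(jax.random.PRNGKey(0))
ts_us = jnp.tile(jnp.linspace(0.0, T1, 20), (3, 1))
uss = jnp.stack([jnp.sin(2 * jnp.pi * ts_us[0]), jnp.ones(20), jnp.zeros(20)])
x0s = jnp.zeros(3)
x_targets = jnp.zeros((3, N_STEPS))
loss_value, grads = grad_fn(params, x0s, ts_us, uss, x_targets)
```

The input time grids may differ between batch elements, as long as every element has the same number of samples so that they can be stacked along the batch axis. With diffrax, the same structure should carry over: build the interpolation inside the vmapped function from that element's arrays (or precomputed coefficients) and hand it to `diffeqsolve` through `args`.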