patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/

Batch training of NODE with varying external input (forcing) per batch element #365

Open FabiJa opened 8 months ago

FabiJa commented 8 months ago

Hi,

Sorry for my slightly uninformed question, but I am new to the JAX ecosystem.

I have several data sets of measurements with different excitations u(t) for one dynamical system, whose dynamics I want to learn. So the excitation changes, but the system (ODE -> NODE) stays the same.

I want to use Equinox + Diffrax to train a neural ODE via batching, where the ODE has an external input u, i.e. it is described by xdot = f(x, u(t)). The dependence of u on time is not known explicitly (it has to be interpolated from data) and varies per batch element.

Looking through the documentation I found the forcing-term example and the batched training of neural ODEs. My problem is how to combine the two. My first hack was to map the u(t) of every batch element onto non-overlapping time periods, so that there is a unique mapping from time to the correct input time series. Then I am able to use vmap directly via

@eqx.filter_value_and_grad
def grad_loss(model, ti, yi):
    # in_axes=(0, 0, None): batch over times and initial states, while sharing the
    # single concatenated input interpolation across all batch elements.
    y_pred = jax.vmap(model, in_axes=(0, 0, None))(ti, yi[:, 0], input_concatenated)
    return jnp.mean((yi - y_pred) ** 2)

Are there better options for handling this? Note that the gradient should not be calculated with respect to the parameters of the interpolation object representing u(t).
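
As far as I understand, eqx.filter_value_and_grad only differentiates with respect to its first argument, so passing the interpolation data as a separate argument should already keep it out of the gradient. A minimal sketch of that assumption (toy arrays, not my real setup):

import equinox as eqx
import jax.numpy as jnp

@eqx.filter_value_and_grad
def toy_loss(params, forcing):
    # `forcing` stands in for the interpolation data; it is not differentiated.
    return jnp.sum(params**2) + jnp.sum(forcing)

value, grads = toy_loss(jnp.ones(3), jnp.ones(5))
# `grads` has the structure of `params` only; `forcing` is treated as a constant.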

Thanks. If there are any questions, let me know.

patrick-kidger commented 8 months ago

Considering the final "forcing term" example: try replacing the jax.grad there with a jax.vmap.

Then you should be able to pass in a batch-of-points, so that each batch element gets a different forcing term.
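
Roughly something like the following sketch (not the docs example verbatim; the toy dynamics and the names solve, ts_u, us are just illustrative):

import diffrax
import jax
import jax.numpy as jnp


def solve(y0, ts_u, us):
    # Build the interpolation of u(t) for a single batch element.
    coeffs = diffrax.backward_hermite_coefficients(ts_u, us)
    interp = diffrax.CubicInterpolation(ts_u, coeffs)

    def vector_field(t, y, args):
        u = interp.evaluate(t)
        return -y + u  # toy dynamics standing in for f(x, u(t))

    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        t0=ts_u[0],
        t1=ts_u[-1],
        dt0=0.1,
        y0=y0,
        saveat=diffrax.SaveAt(ts=ts_u),
    )
    return sol.ys


# Batched: each element of the leading axis gets its own forcing term.
batch_solve = jax.vmap(solve)
ys = batch_solve(
    jnp.ones((32, 2)),                                          # y0, batch of 32
    jnp.broadcast_to(jnp.linspace(0.0, 1.0, 100), (32, 100)),   # per-element time grids
    jnp.zeros((32, 100)),                                       # per-element forcing values
)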

Does that help?

FabiJa commented 8 months ago

Thank you. I will test it and give feedback here.

If I understand correctly, the calculation of the interpolation coefficients happens "on the fly", i.e. during training. If so, it would be nice to have this separated from the training process, also in terms of modularity, in case one wants to change the interpolation scheme.
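
Something like the following sketch is what I have in mind (placeholder shapes and names such as all_ts_inputs / all_us, not my actual data pipeline):

import diffrax
import jax
import jax.numpy as jnp

dataset_size, length = 256, 100
# Placeholders standing in for the measured time grids and inputs u(t) per trajectory.
all_ts_inputs = jnp.broadcast_to(jnp.linspace(0.0, 1.0, length), (dataset_size, length))
all_us = jnp.zeros((dataset_size, length))

# Compute the interpolation coefficients once, before training starts.
coeff_array = jax.vmap(diffrax.backward_hermite_coefficients)(all_ts_inputs, all_us)

# During training only the cheap interpolation object needs to be rebuilt,
# e.g. for the first trajectory:
coeffs_0 = jax.tree_util.tree_map(lambda c: c[0], coeff_array)
interp_0 = diffrax.CubicInterpolation(all_ts_inputs[0], coeffs_0)
u_at_half = interp_0.evaluate(0.5)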

By the way, great work! Astonishing pace of new JAX packages from you :-O

FabiJa commented 8 months ago

OK, a short update: I opted for an alternative approach based on the neural ODE example:

@eqx.filter_value_and_grad
def grad_loss(model, ti, yi, ts_inputs_i, coeff_array_i):
    # Batch over times, initial states, input time grids and interpolation coefficients.
    y_pred = jax.vmap(model, in_axes=(0, 0, 0, 0))(ti, yi[:, 0], ts_inputs_i, coeff_array_i)
    return jnp.mean((yi - y_pred) ** 2)

@eqx.filter_jit
def make_step(ti, yi, ts_inputs_i, coeff_array_i, model, opt_state):
    # The per-batch input data is passed explicitly rather than closed over, so that
    # filter_jit does not bake the first batch's values in as constants.
    loss, grads = grad_loss(model, ti, yi, ts_inputs_i, coeff_array_i)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

...
data_for_batching = (_ys, _ts, ts_inputs_array, coeff_array)
for step, batch_data in zip(
    range(steps), dataloader(data_for_batching, batch_size, key=loader_key)
):
    yi, ti, ts_inputs_i, coeff_array_i = batch_data
    start = time.time()
    loss, model_train, opt_state = make_step(ti, yi, ts_inputs_i, coeff_array_i, model_train, opt_state)
    end = time.time()
    if (step % print_every) == 0 or step == steps - 1:
        print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")

I do not know whether this makes sense, but at least it seems to work (so far). coeff_array_i are the coefficients from e.g. backward_hermite_coefficients.
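
In case it is useful for anyone else, this is roughly how the model consumes those per-trajectory arrays, matching the vmap call above (a simplified sketch, not my full model; the class name and the func network are illustrative):

import diffrax
import equinox as eqx
import jax.numpy as jnp


class NeuralODEWithInput(eqx.Module):
    func: eqx.nn.MLP  # learned vector field, maps concat([x, u]) -> dx/dt

    def __call__(self, ts, y0, ts_inputs, coeffs):
        # Rebuild the interpolation of u(t) from the precomputed coefficients.
        interp = diffrax.CubicInterpolation(ts_inputs, coeffs)

        def vector_field(t, y, args):
            u = interp.evaluate(t)
            return self.func(jnp.concatenate([y, jnp.atleast_1d(u)]))

        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(vector_field),
            diffrax.Tsit5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=y0,
            saveat=diffrax.SaveAt(ts=ts),
        )
        return sol.ys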

patrick-kidger commented 8 months ago

This looks reasonable to me!