patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Transform Feedforward-Network + solver into a Recurrent-Network #109

Closed: SimiPixel closed this issue 2 years ago.

SimiPixel commented 2 years ago

Hello Patrick,

Let me first quickly motivate my feature request. As a side project, I am currently working on model-based optimal control. In a partially observable environment, for example, stateful agents are useful. So, suppose the action selection of an agent is given by the following method:

def select_action(params, state, observation, time):
    apply = neural_network.apply
    state, action = apply(params, state, observation, time)
    return state, action

while True:
    # select_action returns (state, action); keep the new state for the next call
    state, action = select_action(..., observation, env.time)
    observation = env.step(action)

Typically, the apply function is some recurrent neural network. Suppose the environment env is differentiable, because it is just a model of the environment (maybe another network). Now, I would like to replace the recurrent neural network with a feedforward network + solver without changing the API of the agent.

Is constructing the following possible and sensible? That is, I would like to transform a combination of feedforward network + solver into a recurrent network.

def select_action(params, ode_state, observation, time):
    # odeint is a hypothetical API (the interface I would like), not an existing diffrax function
    rhs = lambda x, u: neural_network.apply(params, x, u)
    solution, ode_state = odeint(ode_state, rhs, t1=time, u=(observation, time))
    return ode_state, solution.x(time)

I would like to emphasize that select_action must remain differentiable: specifically, the x-output with respect to the network parameters.

I would love to hear your input :) Anyway, thank you in advance.

patrick-kidger commented 2 years ago

Yep, this is definitely possible. Diffrax is intrinsically differentiable, so no special care is needed. Untested, but perhaps something like the following:

import equinox as eqx
import diffrax as dfx
import jax.numpy as jnp

# wraps an MLP; its input is the concatenation of time, state and observation
class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, state_size, observation_size, width_size, depth, key):
        in_size = 1 + state_size + observation_size
        self.mlp = eqx.nn.MLP(in_size, state_size, width_size, depth, key=key)

    def __call__(self, t, state, observation):
        in_ = jnp.concatenate([t[None], state, observation])
        return self.mlp(in_)

func = Func(...)
get_action = eqx.nn.MLP(...)

def select_action(model, state, observation, time):
    func, get_action = model
    prev_time, state = state
    term = dfx.ODETerm(func)
    # specify solver, dt0, stepsize_controller in whatever way you think appropriate
    solver = ...
    dt0 = ...
    stepsize_controller = ...
    sol = dfx.diffeqsolve(term, solver, prev_time, time, dt0, state, args=observation,
                          stepsize_controller=stepsize_controller)
    (state,) = sol.ys
    action = get_action(state)
    state = (time, state)
    return state, action

It's not critical, but as a nice-to-have, this uses Equinox as a convenient neural network library.

SimiPixel commented 2 years ago

Thank you! Hopefully I will find some time in the next few days to implement this, and I will definitely report back!

Maybe I am just getting too excited about access to a powerful new tool without having to rewrite any old code. But I think a transform that does exactly this would be very nice: given the feedforward network, the solver (and step-size controller), and the measurement/action mapping, it would construct a new differentiable function with the call signature of a typical recurrent network. Not only is this conceptually easier (in my opinion), it also enables plug-and-play integration of neural ODEs into other domains where recurrent neural networks are already well established, especially from an API standpoint.

SimiPixel commented 2 years ago

Doesn't it make more sense to also include the solver state (and the step-size controller state) as part of the carried state of the differential equation? That is, something like this in your example:

stepsize_controller = ...
solver = ...
dt0 = ...
sampling_rate = 100 # Hz

def select_action(model, state, observation):
    func, get_action = model
    term = dfx.ODETerm(func)
    prev_time, state, solver_state, controller_state = state
    time = prev_time + 1/sampling_rate
    sol = dfx.diffeqsolve(term, solver, prev_time, time, dt0, state, args=observation,
                          stepsize_controller=stepsize_controller,
                          controller_state=controller_state, solver_state=solver_state)
    # update controller / solver state
    solver_state, controller_state = ...
    (state,) = sol.ys
    action = get_action(state)
    state = (time, state, solver_state, controller_state)
    return state, action

patrick-kidger commented 2 years ago

Yes, it does. If you do this, then you should also pass made_jump: either the value from the previous step (if your vector field changes smoothly between steps) or True (if your vector field has jumps between steps).

SimiPixel commented 2 years ago

For completeness, let me post my minimal working example. Spoiler: this uses Haiku, simply because I am already comfortable with it. Equinox would probably make this more elegant :)

import diffrax as dfx
import jax.numpy as jnp
from acme.jax import utils 
from functools import partial 
import haiku as hk 
import jax 

sampling_rate = 100 # Hz
stepsize_controller = dfx.ConstantStepSize()
dt0 = 1/sampling_rate
solver = dfx.Euler()

action_size = 3
obs_size = 2
u_dummy = jnp.ones((action_size,))

latent_state_size = 20
hidden_layers = [50,50]
@hk.without_apply_rng
@hk.transform_with_state
def rhs(t, u):
    t = jnp.atleast_1d(t)
    x = hk.get_state("x", shape=(latent_state_size,), init=jnp.zeros, dtype=jnp.float32)
    txu = utils.batch_concat((t, x, u), num_batch_dims=0)
    X = hk.nets.MLP(hidden_layers + [latent_state_size])(txu)
    # return dx/dt with the same pytree structure as the Haiku state {"~": {"x": ...}}
    return {"~": {"x": X}}

def haiku2dfx_rhs(rhs):
    def __rhs(params):
        def _rhs(t, x, u):
            # rhs.apply(params, state, t, u) returns (output, new_state); the latent
            # state x is passed in as the Haiku state and comes back unchanged
            dxdt, x = rhs(params, x, t, u)
            del x
            return dxdt
        return _rhs
    return __rhs

# this is not great / quite confusing
dxdt = haiku2dfx_rhs(rhs.apply)

@hk.without_apply_rng
@hk.transform  
def measurement_function(x):
    x = utils.batch_concat(x, num_batch_dims=0)
    C = hk.get_parameter("C", shape=(obs_size, x.shape[-1]), dtype=jnp.float32,
                         init=lambda shape, dtype: jax.random.normal(hk.next_rng_key(), shape, dtype=dtype))
    return jnp.matmul(C, x)

def gen_init_solver_state(solver: dfx.AbstractSolver, params_rhs, x0):
    term = dfx.ODETerm(dxdt(params_rhs))
    t0 = 0.0
    return solver.init(term, t0=t0, t1=t0 + dt0, y0=x0, args=u_dummy)

def gen_init_controller_state():
    return dt0

saveat = dfx.SaveAt(t1=True,solver_state=True,controller_state=True,made_jump=True)

def step_fun_dynamics_to_time(params, state, u, time):
    prev_time, x, solver_state, controller_state, made_jump = state
    term = dfx.ODETerm(dxdt(params["rhs"]))

    sol = dfx.diffeqsolve(term, solver, prev_time, time, dt0, x, args=u,
                          stepsize_controller=stepsize_controller, saveat=saveat,
                          solver_state=solver_state, controller_state=controller_state,
                          made_jump=made_jump)
    x = sol.ys
    x = utils.squeeze_batch_dim(x)
    state = (time, x, sol.solver_state, sol.controller_state, sol.made_jump)
    obs = measurement_function.apply(params["C"], x)
    return state, obs

def step_fun_dynamics(params, state, u):
    prev_time = state[0]
    return step_fun_dynamics_to_time(params, state, u, prev_time + dt0)

@jax.jit 
@partial(jax.vmap, in_axes=(None, None, 0))
def unrolled_step_fun_dynamics(params, state, us):
   step_fun_dynamics_constraint = lambda state, u: step_fun_dynamics(params, state, u)
   state, obss = jax.lax.scan(step_fun_dynamics_constraint, init=state, xs=us)
   return obss

# initialise parameters
params_rhs, x0 = rhs.init(jax.random.PRNGKey(1), 0.0, u_dummy)
C = measurement_function.init(jax.random.PRNGKey(1), jnp.ones((latent_state_size,)))

params = {
    "rhs": params_rhs,
    "C": C 
}

# initialise step functions state
# (t0, x0, solver_state0, controller_state0, made_jump0)
made_jump0 = False 
init_state = (0.0, x0, gen_init_solver_state(solver, params_rhs, x0), gen_init_controller_state(), made_jump0)

# make prediction
bs=32
T=5.0 
uss = jnp.ones((bs, int(T*sampling_rate), action_size))
obsss = unrolled_step_fun_dynamics(params, init_state, uss)

Thanks, Patrick, for your help. It works perfectly.