stax.serial.apply_fun is not a valid JAX type inside odeint #2920

Closed

skrsna commented 4 years ago

Hi, FWIW, I'm using a self-built jax and jaxlib following instructions from #2083.

I'm trying to get gradients through an ODE solver. First, I ran into the AssertionError in issue #2718, which I think I solved by passing all the arguments directly into odeint. Then I followed the instructions for another AssertionError, issue #2531, by doing vmap of grads instead of grads of vmap. Now I'm getting the following error.

Full trace back.

I'm passing two stax.Serial modules with three Dense layers each as an input to odeint to integrate the Lotka-Volterra ODEs. ufuncs and uparams contain the apply functions and params of the stax.Serial modules.

def lv_UDE(y,t,params,ufuncs,uparams):
    R, F = y
    alpha, theta = params
    U1, U2 = ufuncs
    U1_params, U2_params = uparams
    dRdt = alpha*R - U1(U1_params, y)
    dFdt = -theta*F + U2(U2_params, y)
    return np.array([dRdt,dFdt])

I'm trying to get gradients through an odeint w.r.t uparams. Is there a workaround to pass stax.Serial modules as an argument? Thanks in advance.

shoyer commented 4 years ago

Could you please share a full example of how you get this error? Ideally something that I could copy into a terminal and run.

skrsna commented 4 years ago

Hi, I just noticed that even the non-vmapped version of a function with stax.serial as an input fails with the same error message. Here's the full example. Thanks.

import jax 
import jax.numpy as np
import numpy as onp
from jax import random
from jax import grad, jit, vmap, value_and_grad
from jax.experimental.ode import odeint
from jax.experimental import stax
from functools import partial

def lv(y,t,params):
    """Original Lotka-Volterra equations."""
    R,F = y
    alpha, beta, gamma, theta = params
    dRdt = alpha*R - beta*R*F
    dFdt = gamma*R*F - theta*F
    return np.hstack([dRdt,dFdt])

t = np.linspace(0.,4.,num=1000)
y0 = np.array([0.44249296,4.6280594])

true_y = odeint(partial(lv,params=[1.3,0.9,0.5,1.8]),y0=y0,t=t) #training data generation

def lv_UDE(y,t,params,ufuncs,uparams):
    """Additional parameters include stax.Serial
    modules and the uparams associated with them."""
    R, F = y
    alpha, theta = params
    U1, U2 = ufuncs
    U1_params, U2_params = uparams
    dRdt = alpha*R - U1(U1_params, y)
    dFdt = -theta*F + U2(U2_params, y)
    return np.hstack([dRdt,dFdt])

#two modules of stax Serial
U1_init, U1 = stax.serial(stax.Dense(32),stax.Tanh,
                            stax.Dense(32), stax.Tanh,
                            stax.Dense(1))
U2_init, U2 = stax.serial(stax.Dense(32),stax.Tanh,
                            stax.Dense(32), stax.Tanh,
                            stax.Dense(1))

key, subkey = random.split(random.PRNGKey(0))

_,U1_params = U1_init(key,(2,)) #inputs of size 2
_,U2_params = U2_init(subkey,(2,))
key,subkey = random.split(subkey)

def get_batch():
    """Get batches of initial conditions and
    times along with the true time history."""
    s = onp.random.choice(onp.arange(1000 - 20, 
                        dtype=onp.int64), 20, replace=False)
    batch_y0 = true_y[s]  # (M, D)
    batch_t = t[:20]  # (T)
    batch_y = np.stack([true_y[s + i] for i in range(20)])  # (T, M, D)
    return batch_y0, batch_t, batch_y

def loss(batch_y0, batch_t, batch_y, params, ufuncs,uparams):
    """Mean absolute loss."""
    pred_y = odeint(batch_y0,batch_t,params,ufuncs,uparams) # integrate using odeint
    loss = np.mean(np.abs(pred_y-batch_y)) #calculate loss
    return loss

grads = value_and_grad(loss,(5,)) #grads w.r.t uparams 
batch_grad = vmap(grads,(0, None, None, None, None, None)) #vectorize over initial conditions (batch_y0)

grads(y0,t,true_y,[1.3,1.8], [U1,U2], 
      [U1_params,U2_params]) #non-vmapped version doesn't work
batch_grad(batch_y0, batch_t, batch_y,[1.3,1.8], 
           [U1,U2], [U1_params,U2_params]) #vmap version gives the same error

mattjj commented 4 years ago

Hey @skrsna , thanks for the question!

In your example, it seems the lv_UDE is never called. Is that intentional?

The underlying issue here is that odeint can't take function-valued arguments in *args; those must be arrays (or potentially-nested containers of arrays, like potentially-nested lists/tuples/dicts of arrays). Instead of passing ufuncs via the *args of odeint, maybe you can instead just write something like:

def lv_UDE(ufuncs,y,t,params,uparams):  # moved ufuncs to front

odeint(partial(lv_UDE, ufuncs), ...)
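
A fuller version of this workaround might look like the sketch below. It is only an illustration under stated assumptions: `mlp` and `init_mlp` are hypothetical hand-written stand-ins for the stax apply and init functions, the target array is a placeholder rather than real training data, and the time grid is shortened to keep the run small. The point is that the functions are bound with `partial`, so everything left in odeint's `*args` is an array or a pytree of arrays.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def mlp(params, x):
    """Hypothetical stand-in for a stax apply_fun: one tanh layer, scalar output."""
    (w1, b1), (w2, b2) = params
    hidden = jnp.tanh(x @ w1 + b1)
    return (hidden @ w2 + b2)[0]


def init_mlp(key, n_in=2, n_hidden=8):
    """Hypothetical initializer; returns a pytree (list of tuples) of arrays."""
    k1, k2 = jax.random.split(key)
    return [
        (0.1 * jax.random.normal(k1, (n_in, n_hidden)), jnp.zeros(n_hidden)),
        (0.1 * jax.random.normal(k2, (n_hidden, 1)), jnp.zeros(1)),
    ]


def lv_UDE(ufuncs, y, t, params, uparams):  # ufuncs moved to the front
    R, F = y
    alpha, theta = params
    U1, U2 = ufuncs
    U1_params, U2_params = uparams
    dRdt = alpha * R - U1(U1_params, y)
    dFdt = -theta * F + U2(U2_params, y)
    return jnp.stack([dRdt, dFdt])


ufuncs = (mlp, mlp)  # functions are closed over, never passed through *args


def loss(y0, t, target, params, uparams):
    # Only arrays (params, uparams) go through odeint's *args.
    pred = odeint(partial(lv_UDE, ufuncs), y0, t, params, uparams)
    return jnp.mean(jnp.abs(pred - target))


key1, key2 = jax.random.split(jax.random.PRNGKey(0))
uparams = [init_mlp(key1), init_mlp(key2)]
t = jnp.linspace(0.0, 0.5, 5)
y0 = jnp.array([0.44249296, 4.6280594])
target = jnp.ones((5, 2))  # placeholder target, not real data

# Differentiate w.r.t. uparams (argument index 4).
value, grads = jax.value_and_grad(loss, argnums=4)(
    y0, t, target, jnp.array([1.3, 1.8]), uparams)
```

The returned `grads` has the same nested structure as `uparams`, so it can be fed directly to an optimizer update.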


mattjj commented 4 years ago

It's possible we could support passing function-valued arguments in *args, but I'm not sure it'd be worth the extra complexity. We could at least raise a better error...
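
For illustration, a check of that kind could be sketched as below. This is not the code in JAX itself: `check_odeint_args` is a hypothetical helper, and the set of accepted leaf types is an assumption chosen for the example. It flattens the arguments as a pytree and rejects any leaf that is not an array or a scalar.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Assumed set of acceptable leaf types for this sketch.
_ARRAY_LIKE = (jax.Array, np.ndarray, np.generic, int, float, complex)


def check_odeint_args(*args):
    """Hypothetical helper: raise if any pytree leaf is not array-like."""
    for leaf in jax.tree_util.tree_leaves(args):
        if not isinstance(leaf, _ARRAY_LIKE):
            raise TypeError(
                "The contents of *args must be arrays or scalars, "
                f"but got {leaf!r}.")


# Nested containers of arrays and scalars pass silently.
check_odeint_args(1.0, [jnp.ones(2), (np.zeros(3), 2)])
```

Functions are leaves of a pytree like any other non-container object, so a list such as `[U1, U2]` would be caught by this loop.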

skrsna commented 4 years ago

Hi @mattjj , thanks for the super fast response. My bad, I forgot to add lv_UDE while refactoring the code to make it look nice. I'll try your solution and update the issue with the workaround. Thanks again.

mattjj commented 4 years ago

Awesome, glad to hear that might help!

I just pushed #2931 to improve the error message. Now running your test program we get:

TypeError: The contents of odeint *args must be arrays or scalars, but got
<function serial.<locals>.apply_fun at 0x7f17fc69ca70>.

I also improved the docstring from this:

     *args: tuple of additional arguments for `func`.

To this:

    *args: tuple of additional arguments for `func`, which must be arrays
      scalars, or (nested) standard Python containers (tuples, lists, dicts,
      namedtuples, i.e. pytrees) of those types.

To make odeint handle those types in *args automatically, we could try to hoist non-arrays out of *args inside odeint. But maybe we can open a separate issue for that enhancement if it's a high priority for anyone. (@shoyer interested to hear if you have a strong opinion!)
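
One way such hoisting could work, sketched outside of odeint as a wrapper: split the flattened `*args` into callable and non-callable leaves, close over the callables, and pass only the rest to odeint. The name `odeint_hoisted` and the use of `callable` as the test for "non-array" are assumptions made for this example, not anything in JAX.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def odeint_hoisted(func, y0, t, *args):
    """Hypothetical wrapper: close over callable leaves, forward the rest."""
    leaves, treedef = jax.tree_util.tree_flatten(args)
    is_static = [callable(leaf) for leaf in leaves]
    static = [leaf for leaf, s in zip(leaves, is_static) if s]
    dynamic = [leaf for leaf, s in zip(leaves, is_static) if not s]

    def wrapped(y, t, *dyn):
        # Re-interleave closed-over callables with the traced array leaves.
        dyn_iter, static_iter = iter(dyn), iter(static)
        merged = [next(static_iter) if s else next(dyn_iter) for s in is_static]
        return func(y, t, *jax.tree_util.tree_unflatten(treedef, merged))

    return odeint(wrapped, y0, t, *dynamic)


def decay(y, t, rate, scale_fn):
    return -rate * scale_fn(y)


ts = jnp.linspace(0.0, 1.0, 5)
# A function is passed alongside an array argument; the wrapper hoists it out.
ys = odeint_hoisted(decay, jnp.array([1.0]), ts, jnp.array(2.0), lambda y: y)
```

Because the wrapper builds a new closure on every call, this sketch trades some recompilation for convenience; binding the functions once with `partial` avoids that.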

mattjj commented 4 years ago

I'm going to let #2931 close this issue, just so as to keep our issues under control. Let me know if that's a bad idea :)

skrsna commented 4 years ago

Sure, please close the issue. I'm currently trying out your suggestions and I'll update the issue with working code in case anyone else runs into the same error.

skrsna commented 4 years ago

Hi @mattjj , I tried your solution and it works seamlessly with vmap. Thanks again.
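
The working code was not posted in the thread. As a rough sketch of the pattern described (vmap of grads, with the function bound via `partial`), something like the following should apply; `correction` and `dynamics` are hypothetical toy stand-ins for the stax networks and `lv_UDE`, and the batch arrays are placeholders.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def correction(w, y):
    """Toy stand-in for a learned term (a stax apply_fun in the thread)."""
    return jnp.tanh(y @ w)


def dynamics(u, y, t, w):  # function argument first, bound with partial
    return -y + u(w, y)


def loss(y0, t, target, w):
    pred = odeint(partial(dynamics, correction), y0, t, w)
    return jnp.mean(jnp.abs(pred - target))


# Take grads w.r.t. w (argument 3) first, then vectorize over y0 and target.
batch_grad = jax.vmap(jax.value_and_grad(loss, argnums=3),
                      in_axes=(0, None, 0, None))

t = jnp.linspace(0.0, 1.0, 4)
batch_y0 = jnp.ones((3, 2))          # (M, D)
batch_target = jnp.zeros((3, 4, 2))  # (M, T, D), placeholder values
w = 0.1 * jnp.eye(2)

losses, grads = batch_grad(batch_y0, t, batch_target, w)
```

Here `losses` holds one value per initial condition and `grads` holds one gradient per initial condition, which can be averaged over the leading axis before an update step.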