jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

stax.serial.apply_fun is not a valid JAX type inside odeint #2920

Closed skrsna closed 4 years ago

skrsna commented 4 years ago

Hi, FWIW, I'm using a self-built jax and jaxlib following instructions from #2083.

```
#
# Name                    Version                   Build  Channel
jax                       0.1.64                    <pip>
jaxlib                    0.1.45                    <pip>
```

I'm trying to get gradients through an ODE solver. First, I ran into the AssertionError from issue #2718, which I think I solved by passing all the arguments directly into odeint. Then I followed the instructions in issue #2531 to fix another AssertionError by taking a vmap of grads instead of a grad of a vmap. Now I'm getting the following error.
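
(For context, the rearrangement suggested in #2531 looks roughly like the generic sketch below; `loss_fn`, `params`, and `xs` are made-up names for illustration, not code from this issue.)

```
import jax
import jax.numpy as jnp

def loss_fn(params, x):          # scalar loss for a single example
    return jnp.sum((x * params) ** 2)

params = jnp.ones(2)
xs = jnp.arange(6.0).reshape(3, 2)   # a batch of inputs

# grad(vmap(loss_fn)) fails because grad needs a scalar output;
# vmap(grad(loss_fn)) gives per-example gradients instead.
per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0))(params, xs)
```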

Full traceback:

```
----> 1 batch_grad(batch_y0, batch_t, batch_y,[1.3,1.8], [U1,U2], [U1_params,U2_params])

~/Code/jax/jax/api.py in batched_fun(*args)
    805     _check_axis_sizes(in_tree, args_flat, in_axes_flat)
    806     out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 807                               lambda: _flatten_axes(out_tree(), out_axes))
    808     return tree_unflatten(out_tree(), out_flat)
    809

~/Code/jax/jax/interpreters/batching.py in batch(fun, in_vals, in_dims, out_dim_dests)
     32   # executes a batched version of `fun` following out_dim_dests
     33   batched_fun = batch_fun(fun, in_dims, out_dim_dests)
---> 34   return batched_fun.call_wrapped(*in_vals)
     35
     36 @lu.transformation_with_aux

~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148       gen = None
    149
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

~/Code/jax/jax/api.py in value_and_grad_f(*args, **kwargs)
    436     f_partial, dyn_args = argnums_partial(f, argnums, args)
    437     if not has_aux:
--> 438       ans, vjp_py = _vjp(f_partial, *dyn_args)
    439     else:
    440       ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)

~/Code/jax/jax/api.py in _vjp(fun, *primals, **kwargs)
   1437   if not has_aux:
   1438     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1439     out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1440     out_tree = out_tree()
   1441   else:

~/Code/jax/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    104 def vjp(traceable, primals, has_aux=False):
    105   if not has_aux:
--> 106     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    107   else:
    108     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

~/Code/jax/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
     93   _, in_tree = tree_flatten(((primals, primals), {}))
     94   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
---> 95   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
     96   out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
     97   assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)

~/Code/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
    435   with new_master(trace_type, bottom=bottom) as master:
    436     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 437     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    438     assert not env
    439     del master

~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148       gen = None
    149
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

~/Code/jax/jax/api.py in f_jitted(*args, **kwargs)
    152     flat_fun, out_tree = flatten_fun(f, in_tree)
    153     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
--> 154                        name=flat_fun.__name__)
    155     return tree_unflatten(out_tree(), out)
    156

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
   1003     tracers = map(top_trace.full_raise, args)
   1004     process = getattr(top_trace, processor)
-> 1005     outs = map(full_lower, process(primitive, f, tracers, params))
   1006     return apply_todos(env_trace_todo(), outs)
   1007

~/Code/jax/jax/interpreters/ad.py in process_call(self, call_primitive, f, tracers, params)
    342     name = params.get('name', f.__name__)
    343     params = dict(params, name=wrap_name(name, 'jvp'))
--> 344     result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **params)
    345     primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
    346     return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
   1003     tracers = map(top_trace.full_raise, args)
   1004     process = getattr(top_trace, processor)
-> 1005     outs = map(full_lower, process(primitive, f, tracers, params))
   1006     return apply_todos(env_trace_todo(), outs)
   1007

~/Code/jax/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
    175     in_pvs, in_consts = unzip2([t.pval for t in tracers])
    176     fun, aux = partial_eval(f, self, in_pvs)
--> 177     out_flat = call_primitive.bind(fun, *in_consts, **params)
    178     out_pvs, jaxpr, env = aux()
    179     env_tracers = map(self.full_raise, env)

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
   1003     tracers = map(top_trace.full_raise, args)
   1004     process = getattr(top_trace, processor)
-> 1005     outs = map(full_lower, process(primitive, f, tracers, params))
   1006     return apply_todos(env_trace_todo(), outs)
   1007

~/Code/jax/jax/interpreters/batching.py in process_call(self, call_primitive, f, tracers, params)
    146     else:
    147       f, dims_out = batch_subtrace(f, self.master, dims)
--> 148       vals_out = call_primitive.bind(f, *vals, **params)
    149     return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]
    150

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
    999   if top_trace is None:
   1000     with new_sublevel():
-> 1001       outs = primitive.impl(f, *args, **params)
   1002   else:
   1003     tracers = map(top_trace.full_raise, args)

~/Code/jax/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, *args)
    460
    461 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name):
--> 462   compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
    463   try:
    464     return compiled_fun(*args)

~/Code/jax/jax/linear_util.py in memoized_fun(fun, *args)
    219     fun.populate_stores(stores)
    220   else:
--> 221     ans = call(fun, *args)
    222     cache[key] = (ans, fun.stores)
    223   return ans

~/Code/jax/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, *arg_specs)
    477   pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
    478   jaxpr, pvals, consts = pe.trace_to_jaxpr(
--> 479       fun, pvals, instantiate=False, stage_out=True, bottom=True)
    480
    481   _map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))

~/Code/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
    435   with new_master(trace_type, bottom=bottom) as master:
    436     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 437     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    438     assert not env
    439     del master

~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148       gen = None
    149
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

in loss(batch_y0, batch_t, batch_y, params, ufuncs, uparams)
      1 @partial(jit, static_argnums=(4,))
      2 def loss(batch_y0, batch_t, batch_y, params, ufuncs,uparams):
----> 3     pred_y = odeint(batch_y0,batch_t,params,ufuncs,uparams)
      4     loss = np.mean(np.abs(pred_y-batch_y))
      5     return loss

~/Code/jax/jax/experimental/ode.py in odeint(func, y0, t, rtol, atol, mxstep, *args)
    152     shape/structure as `y0` except with a new leading axis of length `len(t)`.
    153   """
--> 154   return _odeint_wrapper(func, rtol, atol, mxstep, y0, t, *args)
    155
    156 @partial(jax.jit, static_argnums=(0, 1, 2, 3))

~/Code/jax/jax/api.py in f_jitted(*args, **kwargs)
    149       dyn_args = args
    150     args_flat, in_tree = tree_flatten((dyn_args, kwargs))
--> 151     _check_args(args_flat)
    152     flat_fun, out_tree = flatten_fun(f, in_tree)
    153     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,

~/Code/jax/jax/api.py in _check_args(args)
   1558     if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
   1559       raise TypeError("Argument '{}' of type {} is not a valid JAX type"
-> 1560                       .format(arg, type(arg)))
   1561
   1562 def _valid_jaxtype(arg):

TypeError: Argument '<function serial.<locals>.apply_fun at 0x2b06c3d6f7a0>' of type <class 'function'> is not a valid JAX type
```

I'm passing two stax.serial modules (three hidden Dense layers plus a Dense(1) output each) as inputs to odeint to integrate the Lotka-Volterra ODEs. ufuncs and uparams contain the apply functions and the parameters of the stax.serial modules.

def lv_UDE(y,t,params,ufuncs,uparams):
    R, F = y
    alpha, theta = params
    U1, U2 = ufuncs
    U1_params, U2_params = uparams
    dRdt = alpha*R - U1(U1_params, y)
    dFdt = -theta*F + U2(U2_params, y)
    return np.array([dRdt,dFdt])

I'm trying to get gradients through odeint w.r.t. uparams. Is there a workaround for passing stax.serial modules as arguments? Thanks in advance.

shoyer commented 4 years ago

Could you please share a full example of how you get this error? Ideally something that I could copy into a terminal and run.

skrsna commented 4 years ago

Hi, I just noticed that even the non-vmapped version of the function, with the stax.serial apply functions as inputs, errors out with the same message. Here's the full example. Thanks!

import jax 
import jax.numpy as np
import numpy as onp
from jax import random
from jax import grad, jit, vmap, value_and_grad
from jax.experimental.ode import odeint
from jax.experimental import stax
from functools import partial

def lv(y,t,params):
    """
    original lotka-volterra equations
    """
    R,F = y
    alpha, beta, gamma, theta = params
    dRdt = alpha*R - beta*R*F
    dFdt = gamma*R*F - theta*F
    return np.hstack([dRdt,dFdt])

t = np.linspace(0.,4.,num=1000)
y0 = np.array([0.44249296,4.6280594])

true_y = odeint(partial(lv,params=[1.3,0.9,0.5,1.8]),y0=y0,t=t) #training data generation

def lv_UDE(y,t,params,ufuncs,uparams):
    """
    additional parameters include stax.Serial 
    modules and uparams associated with them
    """
    R, F = y
    alpha, theta = params
    U1, U2 = ufuncs
    U1_params, U2_params = uparams
    dRdt = alpha*R - U1(U1_params, y)
    dFdt = -theta*F + U2(U2_params, y)
    return np.hstack([dRdt,dFdt])

#two modules of stax Serial
U1_init, U1 = stax.serial(stax.Dense(32),stax.Tanh, 
                            stax.Dense(32), stax.Tanh, 
                            stax.Dense(32),stax.Tanh,
                           stax.Dense(1))
U2_init, U2 = stax.serial(stax.Dense(32),stax.Tanh, 
                            stax.Dense(32), stax.Tanh, 
                            stax.Dense(32),stax.Tanh,
                           stax.Dense(1))

key, subkey = random.split(random.PRNGKey(0))

_,U1_params = U1_init(key,(2,)) #inputs of size 2
_,U2_params = U2_init(subkey,(2,))
key,subkey = random.split(subkey)

def get_batch():
    """
    Get batches of initial conditions and
    times along with the true time history
    """
    s = onp.random.choice(onp.arange(1000 - 20, 
                        dtype=onp.int64), 20, replace=False)
    batch_y0 = true_y[s]  # (M, D)
    batch_t = t[:20]  # (T)
    batch_y = np.stack([true_y[s + i] for i in range(20)])  # (T, M, D)
    return batch_y0, batch_t, batch_y

def loss(batch_y0, batch_t, batch_y, params, ufuncs,uparams):
    """
    Mean absolute loss 
    """
    pred_y = odeint(batch_y0,batch_t,params,ufuncs,uparams) # integrate using odeint
    loss = np.mean(np.abs(pred_y-batch_y)) #calculate loss
    return loss

grads = value_and_grad(loss, (5,)) # grads w.r.t. uparams
batch_grad = vmap(grads, (0, None, None, None, None, None)) # vectorize over initial conditions (batch_y0)

grads(y0, t, true_y, [1.3,1.8], [U1,U2],
      [U1_params,U2_params]) # non-vmapped version: same error

batch_y0, batch_t, batch_y = get_batch()
batch_grad(batch_y0, batch_t, batch_y, [1.3,1.8],
           [U1,U2], [U1_params,U2_params]) # vmapped version: same error

mattjj commented 4 years ago

Hey @skrsna , thanks for the question!

In your example, it seems the lv_UDE is never called. Is that intentional?

The underlying issue here is that odeint can't take function-valued arguments in *args; those must be arrays (or potentially-nested containers of arrays, like potentially-nested lists/tuples/dicts of arrays). Instead of passing ufuncs via the *args of odeint, maybe you can instead just write something like:

def lv_UDE(ufuncs,y,t,params,uparams):  # moved ufuncs to front
    ...

odeint(partial(lv_UDE, ufuncs), ...)
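
For concreteness, applied to the loss in the example above this might look something like the sketch below (reusing the imports and the stax modules from the reproduction; uparams stays in *args so odeint can still differentiate with respect to it):

```
# Sketch only: the apply functions are closed over via partial, so odeint's
# *args contain just array-valued pytrees (params, uparams).
def lv_UDE(ufuncs, y, t, params, uparams):  # ufuncs moved to the front
    R, F = y
    alpha, theta = params
    U1, U2 = ufuncs
    U1_params, U2_params = uparams
    dRdt = alpha * R - U1(U1_params, y)
    dFdt = -theta * F + U2(U2_params, y)
    return np.hstack([dRdt, dFdt])

def loss(batch_y0, batch_t, batch_y, params, ufuncs, uparams):
    pred_y = odeint(partial(lv_UDE, ufuncs), batch_y0, batch_t, params, uparams)
    return np.mean(np.abs(pred_y - batch_y))

grads = value_and_grad(loss, 5)  # differentiate w.r.t. uparams
batch_grad = vmap(grads, (0, None, None, None, None, None))
```

The call sites then stay the same as in the example, e.g. batch_grad(batch_y0, batch_t, batch_y, [1.3,1.8], [U1,U2], [U1_params,U2_params]).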

WDYT?

mattjj commented 4 years ago

It's possible we could support passing function-valued arguments in *args, but I'm not sure it'd be worth the extra complexity. We could at least raise a better error...

skrsna commented 4 years ago

Hi @mattjj, thanks for the super fast response. My bad, I forgot to pass lv_UDE to odeint while refactoring the code to make it look nicer. I'll try your solution and update the issue with the workaround. Thanks again.

mattjj commented 4 years ago

Awesome, glad to hear that might help!

I just pushed #2931 to improve the error message. Now running your test program we get:

TypeError: The contents of odeint *args must be arrays or scalars, but got
<function serial.<locals>.apply_fun at 0x7f17fc69ca70>.

I also improved the docstring from this:

     *args: tuple of additional arguments for `func`.

To this:

    *args: tuple of additional arguments for `func`, which must be arrays,
      scalars, or (nested) standard Python containers (tuples, lists, dicts,
      namedtuples, i.e. pytrees) of those types.

To make odeint handle those types in *args automatically, we could try to hoist non-arrays out of *args inside odeint. But maybe we can open a separate issue for that enhancement if it's a high priority for anyone. (@shoyer interested to hear if you have a strong opinion!)
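
(For illustration only, one call-site version of that hoisting idea could look like the sketch below; odeint_hoisting_statics is a hypothetical helper, not anything in #2931 or in jax.experimental.ode.)

```
from jax.experimental.ode import odeint
from jax.tree_util import tree_flatten, tree_unflatten


def odeint_hoisting_statics(func, y0, t, *args):
    # Hypothetical helper: split the *args leaves into JAX-typed values and
    # everything else (e.g. functions), pass only the former through odeint,
    # and close over the rest.
    leaves, treedef = tree_flatten(args)
    static = [callable(leaf) for leaf in leaves]   # crude "not an array" test
    dynamic_leaves = [l for l, s in zip(leaves, static) if not s]

    def func_with_statics(y, t, *dyn):
        it = iter(dyn)
        merged = [leaf if s else next(it) for leaf, s in zip(leaves, static)]
        return func(y, t, *tree_unflatten(treedef, merged))

    return odeint(func_with_statics, y0, t, *dynamic_leaves)
```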

mattjj commented 4 years ago

I'm going to let #2931 close this issue, just so as to keep our issues under control. Let me know if that's a bad idea :)

skrsna commented 4 years ago

Sure, please close the issue. I'm currently trying out your suggestions, and I'll update the issue with working code in case anyone else runs into the same error.

skrsna commented 4 years ago

Hi @mattjj , I tried your solution and it works seamlessly with vmap. Thanks again.