Closed: skrsna closed this issue 4 years ago
Could you please share a full example of how you get this error? Ideally something that I could copy into a terminal and run.
Hi,
I just noticed that even the non-vmapped version of a function that takes stax.serial modules as input errors out with the same error message. Here's the full example. Thanks!
import jax
import jax.numpy as np
import numpy as onp
from jax import random
from jax import grad, jit, vmap, value_and_grad
from jax.experimental.ode import odeint
from jax.experimental import stax
from functools import partial
def lv(y,t,params):
"""
original lotka-volterra equations
"""
R,F = y
alpha, beta, gamma, theta = params
dRdt = alpha*R - beta*R*F
dFdt = gamma*R*F - theta*F
return np.hstack([dRdt,dFdt])
t = np.linspace(0.,4.,num=1000)
y0 = np.array([0.44249296,4.6280594])
true_y = odeint(partial(lv,params=[1.3,0.9,0.5,1.8]),y0=y0,t=t) #training data generation
def lv_UDE(y,t,params,ufuncs,uparams):
"""
additional parameters include stax.Serial
modules and uparams associated with them
"""
R, F = y
alpha, theta = params
U1, U2 = ufuncs
U1_params, U2_params = uparams
dRdt = alpha*R - U1(U1_params, y)
dFdt = -theta*F + U2(U2_params, y)
return np.hstack([dRdt,dFdt])
#two modules of stax Serial
U1_init, U1 = stax.serial(stax.Dense(32),stax.Tanh,
stax.Dense(32), stax.Tanh,
stax.Dense(32),stax.Tanh,
stax.Dense(1))
U2_init, U2 = stax.serial(stax.Dense(32),stax.Tanh,
stax.Dense(32), stax.Tanh,
stax.Dense(32),stax.Tanh,
stax.Dense(1))
key, subkey = random.split(random.PRNGKey(0))
_,U1_params = U1_init(key,(2,)) #inputs of size 2
_,U2_params = U2_init(subkey,(2,))
key,subkey = random.split(subkey)
def get_batch():
"""
Get batches of initial conditions and
times along with true time history
"""
s = onp.random.choice(onp.arange(1000 - 20,
dtype=onp.int64), 20, replace=False)
batch_y0 = true_y[s] # (M, D)
batch_t = t[:20] # (T)
batch_y = np.stack([true_y[s + i] for i in range(20)]) # (T, M, D)
return batch_y0, batch_t, batch_y
def loss(batch_y0, batch_t, batch_y, params, ufuncs,uparams):
"""
Mean absolute loss
"""
pred_y = odeint(batch_y0,batch_t,params,ufuncs,uparams) # integrate using odeint
loss = np.mean(np.abs(pred_y-batch_y)) #calculate loss
return loss
grads = value_and_grad(loss,(5,)) #grads w.r.t uparams
batch_grad = vmap(grads,(0, None, None, None, None, None)) #vectorize over initial conditions (batch_y0)
grads(y0, t, true_y, [1.3, 1.8], [U1, U2],
      [U1_params, U2_params]) # non-vmapped version doesn't work
batch_y0, batch_t, batch_y = get_batch()
batch_grad(batch_y0, batch_t, batch_y, [1.3, 1.8],
           [U1, U2], [U1_params, U2_params]) # vmapped version gives the same error
Hey @skrsna, thanks for the question!
In your example, it seems lv_UDE is never called. Is that intentional?
The underlying issue here is that odeint can't take function-valued arguments in *args; those must be arrays (or potentially-nested containers of arrays, like lists/tuples/dicts of arrays). Instead of passing ufuncs via the *args of odeint, maybe you can just write something like:
def lv_UDE(ufuncs,y,t,params,uparams): # moved ufuncs to front
...
odeint(partial(lv_UDE, ufuncs), ...)
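Concretely, with the definitions from your snippet, the loss could look roughly like this (untested sketch; the key point is just that everything reaching odeint's *args is an array, a scalar, or a container of those):

```python
def lv_UDE(ufuncs, y, t, params, uparams):  # ufuncs moved to the front
    R, F = y
    alpha, theta = params
    U1_apply, U2_apply = ufuncs
    U1_params, U2_params = uparams
    dRdt = alpha * R - U1_apply(U1_params, y)
    dFdt = -theta * F + U2_apply(U2_params, y)
    return np.hstack([dRdt, dFdt])

def loss(batch_y0, batch_t, batch_y, params, uparams):
    # the apply functions are baked in via partial, not passed through odeint's *args
    pred_y = odeint(partial(lv_UDE, [U1, U2]), batch_y0, batch_t, params, uparams)
    return np.mean(np.abs(pred_y - batch_y))
```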
WDYT?
It's possible we could support passing function-valued arguments in *args, but I'm not sure it'd be worth the extra complexity. We could at least raise a better error...
Hi @mattjj, thanks for the super-fast response. My bad, I forgot to add the lv_UDE call while refactoring the code to make it look nicer. I'll try your solution and update the issue with the workaround. Thanks again.
Awesome, glad to hear that might help!
I just pushed #2931 to improve the error message. Now running your test program we get:
TypeError: The contents of odeint *args must be arrays or scalars, but got
<function serial.<locals>.apply_fun at 0x7f17fc69ca70>.
I also improved the docstring from this:
*args: tuple of additional arguments for `func`.
To this:
*args: tuple of additional arguments for `func`, which must be arrays,
scalars, or (nested) standard Python containers (tuples, lists, dicts,
namedtuples, i.e. pytrees) of those types.
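In other words, (nested) containers of arrays and scalars are fine in *args, but function objects are not. A toy illustration (made up just to show the rule, not taken from this issue's code):

```python
import jax.numpy as np
from jax.experimental.ode import odeint

def decay(y, t, coeffs):
    return -coeffs["rate"] * y   # coeffs is a dict pytree of scalars/arrays

y0 = np.array([1.0])
ts = np.linspace(0., 1., 10)
ys = odeint(decay, y0, ts, {"rate": 0.5})   # OK: *args holds a pytree of scalars
# odeint(decay, y0, ts, lambda y: y)        # TypeError: a function isn't a valid *args entry
```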
To make odeint handle those types in *args automatically, we could try to hoist non-arrays out of *args inside odeint. But maybe we can open a separate issue for that enhancement if it's a high priority for anyone. (@shoyer interested to hear if you have a strong opinion!)
I'm going to let #2931 close this issue, just so as to keep our issues under control. Let me know if that's a bad idea :)
Sure, please close the issue. I'm currently trying out your suggestions, and I'll update the issue with working code in case anyone else runs into the same error.
Hi @mattjj, I tried your solution and it works seamlessly with vmap. Thanks again.
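For anyone else who lands here, the gradient/vmap part then looks roughly like this, using a loss written as in the suggestion above (reconstructed sketch, so details may differ slightly from what I actually ran):

```python
grads = value_and_grad(loss, 4)  # value and gradient w.r.t. uparams
# vectorize over initial conditions; batch_y is mapped over its batch axis (axis 1)
# so each example is compared against its own trajectory
batch_grad = vmap(grads, (0, None, 1, None, None))

batch_y0, batch_t, batch_y = get_batch()
loss_vals, uparam_grads = batch_grad(batch_y0, batch_t, batch_y,
                                     [1.3, 1.8], [U1_params, U2_params])
```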
Hi, FWIW, I'm using a self-built jax and jaxlib following the instructions from #2083.
I'm trying to get gradients through an ODE solver. First, I ran into the AssertionError from issue #2718, which I think I solved by passing all the arguments directly into odeint. Then I followed the instructions for another AssertionError, issue #2531, by doing vmap of grads instead of grads of vmap. Now I'm getting the following error. Full traceback:
```
----> 1 batch_grad(batch_y0, batch_t, batch_y,[1.3,1.8], [U1,U2], [U1_params,U2_params])

~/Code/jax/jax/api.py in batched_fun(*args)
805 _check_axis_sizes(in_tree, args_flat, in_axes_flat)
806 out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 807 lambda: _flatten_axes(out_tree(), out_axes))
808 return tree_unflatten(out_tree(), out_flat)
809

~/Code/jax/jax/interpreters/batching.py in batch(fun, in_vals, in_dims, out_dim_dests)
32 # executes a batched version of `fun` following out_dim_dests
33 batched_fun = batch_fun(fun, in_dims, out_dim_dests)
---> 34 return batched_fun.call_wrapped(*in_vals)
35
36 @lu.transformation_with_aux

~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
148 gen = None
149
--> 150 ans = self.f(*args, **dict(self.params, **kwargs))
151 del args
152 while stack:

~/Code/jax/jax/api.py in value_and_grad_f(*args, **kwargs)
436 f_partial, dyn_args = argnums_partial(f, argnums, args)
437 if not has_aux:
--> 438 ans, vjp_py = _vjp(f_partial, *dyn_args)
439 else:
440 ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)

~/Code/jax/jax/api.py in _vjp(fun, *primals, **kwargs)
1437 if not has_aux:
1438 flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1439 out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
1440 out_tree = out_tree()
1441 else:

~/Code/jax/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
104 def vjp(traceable, primals, has_aux=False):
105 if not has_aux:
--> 106 out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
107 else:
108 out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

~/Code/jax/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
93 _, in_tree = tree_flatten(((primals, primals), {}))
94 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
---> 95 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
96 out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
97 assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)

~/Code/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
435 with new_master(trace_type, bottom=bottom) as master:
436 fun = trace_to_subjaxpr(fun, master, instantiate)
--> 437 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
438 assert not env
439 del master

~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
148 gen = None
149
--> 150 ans = self.f(*args, **dict(self.params, **kwargs))
151 del args
152 while stack:

~/Code/jax/jax/api.py in f_jitted(*args, **kwargs)
152 flat_fun, out_tree = flatten_fun(f, in_tree)
153 out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
--> 154 name=flat_fun.__name__)
155 return tree_unflatten(out_tree(), out)
156

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
1003 tracers = map(top_trace.full_raise, args)
1004 process = getattr(top_trace, processor)
-> 1005 outs = map(full_lower, process(primitive, f, tracers, params))
1006 return apply_todos(env_trace_todo(), outs)
1007

~/Code/jax/jax/interpreters/ad.py in process_call(self, call_primitive, f, tracers, params)
342 name = params.get('name', f.__name__)
343 params = dict(params, name=wrap_name(name, 'jvp'))
--> 344 result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **params)
345 primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
346 return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
1003 tracers = map(top_trace.full_raise, args)
1004 process = getattr(top_trace, processor)
-> 1005 outs = map(full_lower, process(primitive, f, tracers, params))
1006 return apply_todos(env_trace_todo(), outs)
1007

~/Code/jax/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
175 in_pvs, in_consts = unzip2([t.pval for t in tracers])
176 fun, aux = partial_eval(f, self, in_pvs)
--> 177 out_flat = call_primitive.bind(fun, *in_consts, **params)
178 out_pvs, jaxpr, env = aux()
179 env_tracers = map(self.full_raise, env)

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
1003 tracers = map(top_trace.full_raise, args)
1004 process = getattr(top_trace, processor)
-> 1005 outs = map(full_lower, process(primitive, f, tracers, params))
1006 return apply_todos(env_trace_todo(), outs)
1007

~/Code/jax/jax/interpreters/batching.py in process_call(self, call_primitive, f, tracers, params)
146 else:
147 f, dims_out = batch_subtrace(f, self.master, dims)
--> 148 vals_out = call_primitive.bind(f, *vals, **params)
149 return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]
150

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
999 if top_trace is None:
1000 with new_sublevel():
-> 1001 outs = primitive.impl(f, *args, **params)
1002 else:
1003 tracers = map(top_trace.full_raise, args)

~/Code/jax/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, *args)
460
461 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name):
--> 462 compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
463 try:
464 return compiled_fun(*args)

~/Code/jax/jax/linear_util.py in memoized_fun(fun, *args)
219 fun.populate_stores(stores)
220 else:
--> 221 ans = call(fun, *args)
222 cache[key] = (ans, fun.stores)
223 return ans

~/Code/jax/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, *arg_specs)
477 pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
478 jaxpr, pvals, consts = pe.trace_to_jaxpr(
--> 479 fun, pvals, instantiate=False, stage_out=True, bottom=True)
480
481 _map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))

~/Code/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
435 with new_master(trace_type, bottom=bottom) as master:
436 fun = trace_to_subjaxpr(fun, master, instantiate)
--> 437 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
438 assert not env
439 del master

~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
148 gen = None
149
--> 150 ans = self.f(*args, **dict(self.params, **kwargs))
151 del args
152 while stack:

in loss(batch_y0, batch_t, batch_y, params, ufuncs, uparams)
1 @partial(jit, static_argnums=(4,))
2 def loss(batch_y0, batch_t, batch_y, params, ufuncs,uparams):
----> 3 pred_y = odeint(batch_y0,batch_t,params,ufuncs,uparams)
4 loss = np.mean(np.abs(pred_y-batch_y))
5 return loss
~/Code/jax/jax/experimental/ode.py in odeint(func, y0, t, rtol, atol, mxstep, *args)
152 shape/structure as `y0` except with a new leading axis of length `len(t)`.
153 """
--> 154 return _odeint_wrapper(func, rtol, atol, mxstep, y0, t, *args)
155
156 @partial(jax.jit, static_argnums=(0, 1, 2, 3))
~/Code/jax/jax/api.py in f_jitted(*args, **kwargs)
149 dyn_args = args
150 args_flat, in_tree = tree_flatten((dyn_args, kwargs))
--> 151 _check_args(args_flat)
152 flat_fun, out_tree = flatten_fun(f, in_tree)
153 out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
~/Code/jax/jax/api.py in _check_args(args)
1558 if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
1559 raise TypeError("Argument '{}' of type {} is not a valid JAX type"
-> 1560 .format(arg, type(arg)))
1561
1562 def _valid_jaxtype(arg):
TypeError: Argument '<function serial.<locals>.apply_fun at 0x2b06c3d6f7a0>' of type <class 'function'> is not a valid JAX type
```
I'm passing two stax.serial modules, each with three hidden Dense layers, as inputs to odeint in order to integrate the Lotka-Volterra ODEs. ufuncs and uparams contain the apply functions and params of the stax.serial modules. I'm trying to get gradients through odeint w.r.t. uparams. Is there a workaround for passing stax.serial modules as arguments? Thanks in advance.