Hi,
I'm trying to calculate the gradient through the odeint solver. When my variables are real numbers it works, but if I make them complex I get the error message below. How do I solve this issue?
Best
Thanks for the report and the clear example to reproduce it. This looks like a bug to me.
For the record, here is the traceback:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-7-9f91f66ee8c8> in <module>()
15
16 x = 3. + 1j
---> 17 grad = jax.grad(f)(x)
/usr/local/lib/python3.6/dist-packages/jax/api.py in grad_f(*args, **kwargs)
411 @wraps(fun, docstr=docstr, argnums=argnums)
412 def grad_f(*args, **kwargs):
--> 413 _, g = value_and_grad_f(*args, **kwargs)
414 return g
415
/usr/local/lib/python3.6/dist-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
475 dtype = dtypes.result_type(ans)
476 tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
--> 477 g = vjp_py(np.ones((), dtype=dtype))
478 g = g[0] if isinstance(argnums, int) else g
479 if not has_aux:
/usr/local/lib/python3.6/dist-packages/jax/api.py in _vjp_pullback_wrapper(cotangent_dtypes, io_tree, fun, py_args)
1505 "match type of corresponding primal output ({})")
1506 raise TypeError(msg.format(_dtype(a), dtype))
-> 1507 ans = fun(*args)
1508 return tree_unflatten(out_tree, ans)
1509
/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in unbound_vjp(pvals, jaxpr, consts, *cts)
114 cts = tuple(map(ignore_consts, cts, pvals))
115 dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
--> 116 arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
117 return map(instantiate_zeros, arg_cts)
118
/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in backward_pass(jaxpr, consts, primals_in, cotangents_in)
200 call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
201 cts_out = get_primitive_transpose(eqn.primitive)(
--> 202 params, call_jaxpr, invals, cts_in, cts_in_avals)
203 else:
204 cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in call_transpose(primitive, params, call_jaxpr, args, ct, _)
486 new_params = update_params(new_params, map(is_undefined_primal, args),
487 [type(x) is not Zero for x in ct])
--> 488 out_flat = primitive.bind(fun, *all_args, **new_params)
489 return tree_unflatten(out_tree(), out_flat)
490 primitive_transposes[core.call_p] = partial(call_transpose, call_p)
/usr/local/lib/python3.6/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1096 if top_trace is None:
1097 with new_sublevel():
-> 1098 outs = primitive.impl(fun, *args, **params)
1099 else:
1100 tracers = map(top_trace.full_raise, args)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
536 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
537 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 538 *unsafe_map(arg_spec, args))
539 try:
540 return compiled_fun(*args)
/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
219 fun.populate_stores(stores)
220 else:
--> 221 ans = call(fun, *args)
222 cache[key] = (ans, fun.stores)
223 return ans
/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
602 pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
603 jaxpr, pvals, consts = pe.trace_to_jaxpr(
--> 604 fun, pvals, instantiate=False, stage_out=True, bottom=True)
605 map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
606 jaxpr = apply_outfeed_rewriter(jaxpr)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
427 with new_master(trace_type, bottom=bottom) as master:
428 fun = trace_to_subjaxpr(fun, master, instantiate)
--> 429 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
430 assert not env
431 del master
/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
148 gen = None
149
--> 150 ans = self.f(*args, **dict(self.params, **kwargs))
151 del args
152 while stack:
/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in backward_pass(jaxpr, consts, primals_in, cotangents_in)
203 else:
204 cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
--> 205 **eqn.params)
206 cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
207 # FIXME: Some invars correspond to primals!
/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in _custom_lin_transpose(cts_out, num_res, bwd, avals_out, *invals)
605 res, _ = split_list(invals, [num_res])
606 cts_out = map(instantiate_zeros_aval, avals_out, cts_out)
--> 607 cts_in = bwd.call_wrapped(*res, *cts_out)
608 cts_in_flat, _ = tree_flatten(cts_in) # already checked tree structure
609 return [None] * num_res + cts_in_flat
/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
148 gen = None
149
--> 150 ans = self.f(*args, **dict(self.params, **kwargs))
151 del args
152 while stack:
/usr/local/lib/python3.6/dist-packages/jax/experimental/ode.py in _odeint_rev(func, rtol, atol, mxstep, res, g)
285 init_carry = (g[-1], 0., tree_map(jnp.zeros_like, args))
286 (y_bar, t0_bar, args_bar), rev_ts_bar = lax.scan(
--> 287 scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1))
288 ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]])
289 return (y_bar, ts_bar, *args_bar)
/usr/local/lib/python3.6/dist-packages/jax/lax/lax_control_flow.py in scan(f, init, xs, length, reverse, unroll)
1228 # Extract the subtree and avals for the first element of the return tuple
1229 out_tree_children[0], jaxpr.out_avals[:out_tree_children[0].num_leaves],
-> 1230 init_tree, carry_avals)
1231
1232 out = scan_p.bind(*itertools.chain(consts, in_flat),
/usr/local/lib/python3.6/dist-packages/jax/lax/lax_control_flow.py in _check_tree_and_avals(what, tree1, avals1, tree2, avals2)
1873 "got\n{}\nand\n{}.")
1874 raise TypeError(msg.format(what, tree_unflatten(tree1, avals1),
-> 1875 tree_unflatten(tree2, avals2)))
1876
1877
TypeError: scan carry output and input must have identical types, got
(ShapedArray(complex64[1]), ShapedArray(complex64[]), ())
and
(ShapedArray(complex64[1]), ShapedArray(float32[]), ()).
The error occurs inside the lax.scan() function, when the carry output is checked against the carry input for the same pytree structure and types. As the error message shows, the input and output types do not match:
(ShapedArray(complex64[1]), ShapedArray(complex64[]), ())
and
(ShapedArray(complex64[1]), ShapedArray(float32[]), ()).
The carry initialization contains a float32 where the corresponding output is complex64. The carry input (containing the erroneous float32 zero) is passed to lax.scan() from the ode._odeint_rev() function:
https://github.com/google/jax/blob/4c22e012d27ca865d429492f0e5956791907a981/jax/experimental/ode.py#L290-L292
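For illustration, here is a minimal standalone sketch (not from the issue) that trips the same carry-type check in lax.scan():

```python
import jax.numpy as jnp
from jax import lax

def step(carry, _):
    y, t = carry
    # t picks up a complex dtype here, so the carry type changes mid-scan:
    return (y, t * 1j), None

init = (jnp.zeros(1, jnp.complex64), jnp.float32(0.0))
# Raises: TypeError: scan carry output and input must have identical types
lax.scan(step, init, jnp.arange(3))
```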
My suggestion would be to change the 0. to g.dtype.type() so that the second carry entry is a zero of the same dtype as the first. Is there a special reason to initialize the second carry entry as float32? If not, I will create a PR for this, as it makes the code given by @AnnahDo run and does not break any of the tests for the ODE module.
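Concretely, a sketch of the proposed one-line change (the current line is the init_carry shown in the traceback above):

```python
# Current: the Python float 0. makes the t0_bar slot float32, which
# clashes with the complex carry output later in the scan.
init_carry = (g[-1], 0., tree_map(jnp.zeros_like, args))

# Proposed: build the zero from g's dtype so the carry entries keep
# consistent types even when g is complex.
init_carry = (g[-1], g.dtype.type(), tree_map(jnp.zeros_like, args))
```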
Mathematically, t_bar is the gradient with respect to observation time. The observation time is always real-valued, which I believe implies that the gradient should always be real-valued too, according to JAX's autodiff convention for complex numbers. So my guess is that allowing t_bar to be complex-valued could break things in a different way.
I'm not sure why our current ODE gradient (from the Neural ODE paper) doesn't guarantee this property. My guess is that some sort of mathematical correction needs to be applied to the algorithm, possibly as simple as casting t_bar to be real-valued.
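As a small standalone illustration of that convention (not from the thread): a function with complex intermediates but a real input and a real output still gets a real gradient:

```python
import jax
import jax.numpy as jnp

def f(t):
    # complex intermediate, real-valued output
    return jnp.abs(jnp.exp(1j * t) + 1.0)

print(jax.grad(f)(0.5))  # a real scalar, despite the complex arithmetic inside
```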
@duvenaud any ideas?
To be honest, I've never thought about this, and the correct answer doesn't seem obvious. But what @shoyer suggests sounds right, and I wouldn't be above sticking a cast on t_bar and seeing if the gradients match numerically.
t_bar is complex after the following line, because in the example both the return value of func and g are complex: https://github.com/google/jax/blob/e5c4ccbfc8c6113ddcbbf099918c76af7289be25/jax/experimental/ode.py#L278
Adding the line t_bar = lax.convert_element_type(t_bar, t0_bar.dtype) right after it fixes the error. Is this the kind of cast you had in mind? t0_bar comes from the carry inside the scan function and is initialized as 0. (float32).
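For reference, converting a complex value to a real dtype this way keeps the real part and drops the imaginary part (standalone sketch):

```python
import jax.numpy as jnp
from jax import lax

t_bar = jnp.complex64(1.5 + 2.0j)
print(lax.convert_element_type(t_bar, jnp.float32))  # 1.5: imaginary part discarded
```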
In the example, jax.grad produces (2.006775-0.60499126j), and the numerical gradients for different epsilons are:
epsilon=1e-01 (2.006692886352539+0.6049776077270508j)
epsilon=1e-02 (2.0071983337402344+0.6052970886230469j)
epsilon=1e-03 (2.0084381103515625+0.606536865234375j)
epsilon=1e-04 (2.02178955078125+0.6389617919921875j)
epsilon=1e-05 (2.09808349609375+0.6198883056640625j)
epsilon=1e-06 (2.86102294921875+0j)
epsilon=1e-07 0j
This runs into precision problems with JAX using float32 by default (the finite differences cancel catastrophically at small epsilon), but after setting the jax_enable_x64 flag to true, the numerical gradients are:
epsilon=1e-01 (2.0067512197545323+0.6049284956391965j)
epsilon=1e-02 (2.0067740844419646+0.6049910193928199j)
epsilon=1e-03 (2.0067743130436355+0.6049916447228298j)
epsilon=1e-04 (2.0067743151486184+0.6049916509098807j)
epsilon=1e-05 (2.006774314455839+0.6049916498440666j)
epsilon=1e-06 (2.0067743200513632+0.6049916496664309j)
epsilon=1e-07 (2.0067742312335213+0.6049916745354267j)
These now match the gradient from jax.grad, up to the sign of the imaginary part (i.e., complex conjugation, consistent with JAX's convention for gradients of real-valued functions of complex inputs).
The code I used to check the gradients:
```python
from jax.config import config
config.update('jax_enable_x64', True)

import jax
import jax.numpy as np
from jax.experimental import ode


def rhs(u, t):
    deriv = u
    return deriv


def f(x):
    next_x = ode.odeint(rhs, x, np.array([0.1, 0.2]))
    res = np.sum(np.abs(next_x))
    return res


def grad_numerical(f, x, eps=1e-4):
    return ((f(x + eps / 2) - f(x - eps / 2)) / eps
            + (f(x + eps / 2 * 1j) - f(x - eps / 2 * 1j)) / (eps * 1j))


x = 3. + 1j
grad = jax.grad(f)(x)
print('Gradient from jax.grad:', grad)
print()
print('Numerical estimates:')
for eps in (10 ** -i for i in range(1, 8)):
    print(f'epsilon={eps:.0e}\t {grad_numerical(f, x, eps)}')
```
> Adding the line t_bar = lax.convert_element_type(t_bar, t0_bar.dtype) right after that fixes the error. Is this the kind of cast you had in mind? t0_bar comes from the carry inside the scan function and is initialized as 0. (float32).
This is exactly what @duvenaud and I were suggesting. I'm glad it seems to check out!
Any interest in putting together a pull request with a new test? To make it slightly more comprehensive, I might suggest also adding a single complex-valued parameter alpha to differentiate against as well, i.e., testing the ODE ∂y/∂t = αy.
(If you look at the existing tests in tests/ode_test.py, you'll note that we already have a check_grads() utility very similar to what you wrote here.)
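A minimal sketch of what such a test could look like (assuming jax.test_util.check_grads; the helpers in tests/ode_test.py may differ):

```python
import jax.numpy as jnp
from jax.experimental import ode
from jax.test_util import check_grads

def exp_ode(y, t, alpha):
    # the suggested test ODE: dy/dt = alpha * y, with complex alpha
    return alpha * y

def solve(alpha):
    y0 = jnp.array([1.0 + 0.5j])
    ts = jnp.array([0.0, 1.0])
    ys = ode.odeint(exp_ode, y0, ts, alpha)
    return jnp.sum(jnp.abs(ys))  # real-valued loss

check_grads(solve, (0.3 + 0.7j,), order=1, modes=['rev'])
```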
I'm going to tentatively mark this as "fixed" by #4130.