jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

differentiating through the odeint when using complex numbers results in type errors #3986

Closed · annahdo closed this 4 years ago

annahdo commented 4 years ago

Hi,

I'm trying to calculate the gradient through the odeint solver. When my variables are real numbers it works, but if I make them complex, I get the following error message:

TypeError: scan carry output and input must have identical types, got
(ShapedArray(complex64[1]), ShapedArray(complex64[]), ())
and
(ShapedArray(complex64[1]), ShapedArray(float32[]), ()).

How do I solve this issue?

best

import jax.numpy as np
import jax
from jax.experimental import ode as ode

def rhs(u, t):
    deriv = u
    return deriv

def f(x):
    next_x = ode.odeint(rhs, x, np.array([0.1, 0.2]))
    res = np.sum(np.abs(next_x))
    return res

# Differentiating with respect to a complex input triggers the error quoted above.
x = 3. + 1j
grad = jax.grad(f)(x)
shoyer commented 4 years ago

Thanks for the report and the clear example to reproduce it. This looks like a bug to me.

For the record, here is the traceback:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-7-9f91f66ee8c8> in <module>()
     15 
     16 x = 3. + 1j
---> 17 grad = jax.grad(f)(x)

17 frames
/usr/local/lib/python3.6/dist-packages/jax/api.py in grad_f(*args, **kwargs)
    411   @wraps(fun, docstr=docstr, argnums=argnums)
    412   def grad_f(*args, **kwargs):
--> 413     _, g = value_and_grad_f(*args, **kwargs)
    414     return g
    415 

/usr/local/lib/python3.6/dist-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
    475     dtype = dtypes.result_type(ans)
    476     tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
--> 477     g = vjp_py(np.ones((), dtype=dtype))
    478     g = g[0] if isinstance(argnums, int) else g
    479     if not has_aux:

/usr/local/lib/python3.6/dist-packages/jax/api.py in _vjp_pullback_wrapper(cotangent_dtypes, io_tree, fun, py_args)
   1505              "match type of corresponding primal output ({})")
   1506       raise TypeError(msg.format(_dtype(a), dtype))
-> 1507   ans = fun(*args)
   1508   return tree_unflatten(out_tree, ans)
   1509 

/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in unbound_vjp(pvals, jaxpr, consts, *cts)
    114     cts = tuple(map(ignore_consts, cts, pvals))
    115     dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
--> 116     arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
    117     return map(instantiate_zeros, arg_cts)
    118 

/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in backward_pass(jaxpr, consts, primals_in, cotangents_in)
    200         call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
    201         cts_out = get_primitive_transpose(eqn.primitive)(
--> 202             params, call_jaxpr, invals, cts_in, cts_in_avals)
    203       else:
    204         cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,

/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in call_transpose(primitive, params, call_jaxpr, args, ct, _)
    486     new_params = update_params(new_params, map(is_undefined_primal, args),
    487                                [type(x) is not Zero for x in ct])
--> 488   out_flat = primitive.bind(fun, *all_args, **new_params)
    489   return tree_unflatten(out_tree(), out_flat)
    490 primitive_transposes[core.call_p] = partial(call_transpose, call_p)

/usr/local/lib/python3.6/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1096   if top_trace is None:
   1097     with new_sublevel():
-> 1098       outs = primitive.impl(fun, *args, **params)
   1099   else:
   1100     tracers = map(top_trace.full_raise, args)

/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    536 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
    537   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 538                                *unsafe_map(arg_spec, args))
    539   try:
    540     return compiled_fun(*args)

/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
    219       fun.populate_stores(stores)
    220     else:
--> 221       ans = call(fun, *args)
    222       cache[key] = (ans, fun.stores)
    223     return ans

/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    602   pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
    603   jaxpr, pvals, consts = pe.trace_to_jaxpr(
--> 604       fun, pvals, instantiate=False, stage_out=True, bottom=True)
    605   map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
    606   jaxpr = apply_outfeed_rewriter(jaxpr)

/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
    427   with new_master(trace_type, bottom=bottom) as master:
    428     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 429     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    430     assert not env
    431     del master

/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in backward_pass(jaxpr, consts, primals_in, cotangents_in)
    203       else:
    204         cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
--> 205                                                          **eqn.params)
    206     cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
    207     # FIXME: Some invars correspond to primals!

/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in _custom_lin_transpose(cts_out, num_res, bwd, avals_out, *invals)
    605   res, _ = split_list(invals, [num_res])
    606   cts_out = map(instantiate_zeros_aval, avals_out, cts_out)
--> 607   cts_in = bwd.call_wrapped(*res, *cts_out)
    608   cts_in_flat, _ = tree_flatten(cts_in)  # already checked tree structure
    609   return [None] * num_res + cts_in_flat

/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

/usr/local/lib/python3.6/dist-packages/jax/experimental/ode.py in _odeint_rev(func, rtol, atol, mxstep, res, g)
    285   init_carry = (g[-1], 0., tree_map(jnp.zeros_like, args))
    286   (y_bar, t0_bar, args_bar), rev_ts_bar = lax.scan(
--> 287       scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1))
    288   ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]])
    289   return (y_bar, ts_bar, *args_bar)

/usr/local/lib/python3.6/dist-packages/jax/lax/lax_control_flow.py in scan(f, init, xs, length, reverse, unroll)
   1228                         # Extract the subtree and avals for the first element of the return tuple
   1229                         out_tree_children[0], jaxpr.out_avals[:out_tree_children[0].num_leaves],
-> 1230                         init_tree, carry_avals)
   1231 
   1232   out = scan_p.bind(*itertools.chain(consts, in_flat),

/usr/local/lib/python3.6/dist-packages/jax/lax/lax_control_flow.py in _check_tree_and_avals(what, tree1, avals1, tree2, avals2)
   1873            "got\n{}\nand\n{}.")
   1874     raise TypeError(msg.format(what, tree_unflatten(tree1, avals1),
-> 1875                                tree_unflatten(tree2, avals2)))
   1876 
   1877 

TypeError: scan carry output and input must have identical types, got
(ShapedArray(complex64[1]), ShapedArray(complex64[]), ())
and
(ShapedArray(complex64[1]), ShapedArray(float32[]), ()).
PhilippThoelke commented 4 years ago

The error occurs inside the lax.scan() function when the carry output is checked against the input for the same pytree structure and types. As the error message shows, the types of the input and output do not match:

(ShapedArray(complex64[1]), ShapedArray(complex64[]), ())
and
(ShapedArray(complex64[1]), ShapedArray(float32[]), ())
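
For illustration, the same check can be triggered outside the ODE code with a tiny lax.scan whose carry changes dtype across a step (a minimal standalone sketch, not JAX's ODE code):

import jax.numpy as jnp
from jax import lax

def step(carry, x):
    new_carry = carry + 1j * x  # promotes the float32 carry to complex64
    return new_carry, x

# Raises: TypeError: scan carry output and input must have identical types, ...
lax.scan(step, jnp.float32(0.), jnp.arange(3, dtype=jnp.float32))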

The input/carry initialization contains float32 where the output is complex64. The carry input (containing the erroneous float32) is passed to lax.scan() from the ode._odeint_rev() function: https://github.com/google/jax/blob/4c22e012d27ca865d429492f0e5956791907a981/jax/experimental/ode.py#L290-L292

My suggestion would be to change the 0. to g.dtype.type(), so that the second carry entry is initialized as a zero with the same dtype as the first.

Is there a special reason to initialize the second carry value as float32? If not, I will create a PR for this, as it makes the code given by @annahdo run and does not break any of the tests for the ODE module.
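
For concreteness, that suggestion applied to the init_carry line quoted in the traceback above would look roughly like this (a sketch; the surrounding lines are copied from the traceback, not the current source):

init_carry = (g[-1], g.dtype.type(), tree_map(jnp.zeros_like, args))  # was: (g[-1], 0., ...)
(y_bar, t0_bar, args_bar), rev_ts_bar = lax.scan(
    scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1))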

shoyer commented 4 years ago

Mathematically, t_bar is the gradient with respect to observation time. The observation time is always real-valued, which I believe implies that the gradient should always be real-valued too, according to JAX's autodiff convention for complex numbers. So my guess is that allowing t_bar to be complex-valued could break things in a different way.
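
As a small illustration of that convention (an illustrative sketch, not code from this issue), differentiating a real-input, real-output function through complex intermediates gives a real-valued gradient:

import jax
import jax.numpy as jnp

def observe(t):
    # real input, complex intermediate, real output
    return jnp.real(jnp.exp(1j * t) * (2. + 1.j))

dt = jax.grad(observe)(0.3)
print(dt, dt.dtype)  # the gradient with respect to the real-valued t is float, not complex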

I'm not sure why our current ODE gradient (from the Neural ODE paper) doesn't guarantee this property. My guess is that some sort of mathematical correction needs to be applied to the algorithm, possibly as simple as casting t_bar to be real-valued.

@duvenaud any ideas?

duvenaud commented 4 years ago

To be honest, I've never thought about this, and the correct answer doesn't seem obvious. But what @shoyer suggests sounds right, and I wouldn't be above sticking a cast on t_bar and seeing if the gradients match numerically.

PhilippThoelke commented 4 years ago

t_bar is complex after the following line because in the example both the return value of func and g are complex: https://github.com/google/jax/blob/e5c4ccbfc8c6113ddcbbf099918c76af7289be25/jax/experimental/ode.py#L278

Adding the line t_bar = lax.convert_element_type(t_bar, t0_bar.dtype) right after that fixes the error. Is this the kind of cast you had in mind? t0_bar comes from the carry inside the scan function and is initialized as 0. (float32).
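
In other words, the cast would sit roughly here inside scan_fun (the surrounding lines are paraphrased from the source linked above, so treat this as a sketch):

t_bar = jnp.dot(func(ys[i], ts[i], *args), g[i])
# Observation times are real-valued, so force their cotangent back to a real dtype.
t_bar = lax.convert_element_type(t_bar, t0_bar.dtype)
t0_bar = t0_bar - t_bar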

In the example, jax.grad produces (2.006775-0.60499126j), and the numerical gradients with different epsilons are

epsilon=1e-01    (2.006692886352539+0.6049776077270508j)
epsilon=1e-02    (2.0071983337402344+0.6052970886230469j)
epsilon=1e-03    (2.0084381103515625+0.606536865234375j)
epsilon=1e-04    (2.02178955078125+0.6389617919921875j)
epsilon=1e-05    (2.09808349609375+0.6198883056640625j)
epsilon=1e-06    (2.86102294921875+0j)
epsilon=1e-07    0j

This seems to run into problems with jax using float32 by default, but after setting the jax_enable_x64 flag to true, the numerical gradients are

epsilon=1e-01    (2.0067512197545323+0.6049284956391965j)
epsilon=1e-02    (2.0067740844419646+0.6049910193928199j)
epsilon=1e-03    (2.0067743130436355+0.6049916447228298j)
epsilon=1e-04    (2.0067743151486184+0.6049916509098807j)
epsilon=1e-05    (2.006774314455839+0.6049916498440666j)
epsilon=1e-06    (2.0067743200513632+0.6049916496664309j)
epsilon=1e-07    (2.0067742312335213+0.6049916745354267j)

These now match the gradient from jax.grad.

The code I used to check the gradients:

from jax.config import config
config.update('jax_enable_x64', True)

import jax.numpy as np
import jax
from jax.experimental import ode as ode

def rhs(u, t):
    deriv = u
    return deriv

def f(x):
    next_x = ode.odeint(rhs, x, np.array([0.1, 0.2]))
    res = np.sum(np.abs(next_x))
    return res

def grad_numerical(f, x, eps=1e-4):
    # Central differences along the real and imaginary axes of x.
    return ((f(x + eps / 2) - f(x - eps / 2)) / eps
            + (f(x + eps / 2 * 1j) - f(x - eps / 2 * 1j)) / (eps * 1j))

x = 3. + 1j
grad = jax.grad(f)(x)

print('Gradient from jax.grad:', grad)
print()

print('Numerical estimates:')
for eps in (10 ** -i for i in range(1, 8)):
    print(f'epsilon={eps:.0e}\t {grad_numerical(f, x, eps)}')
shoyer commented 4 years ago

Adding the line t_bar = lax.convert_element_type(t_bar, t0_bar.dtype) right after that fixes the error. Is this the kind of cast you had in mind? t0_bar comes from the carry inside the scan function and is initialized as 0. (float32).

This is exactly what @duvenaud and I were suggesting. I'm glad it seems to check out!

Any interest in putting together a pull request with a new test? To make it slightly more comprehensive, I might suggest also adding a single complex-valued parameter alpha to differentiate against as well, i.e., testing the ODE ∂y/∂t = α y.

(If you look at the existing tests in tests/ode_test.py, you'll note that we already have a check_grads() utility very similar to what you wrote here.)
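
A minimal sketch of such a test (illustrative only; it assumes the fix discussed in this thread and uses jax.test_util.check_grads like the existing ODE tests):

import jax.numpy as jnp
from jax.experimental.ode import odeint
from jax.test_util import check_grads

def dy_dt(y, t, alpha):
    # The ODE dy/dt = alpha * y with a complex-valued parameter alpha.
    return alpha * y

def solve(y0, alpha):
    return odeint(dy_dt, y0, jnp.linspace(0., 1., 11), alpha)

y0 = jnp.array([1. + 2.j])
alpha = 0.5 + 0.3j
# Compare reverse-mode gradients w.r.t. both y0 and alpha against numerical estimates.
check_grads(solve, (y0, alpha), order=1, modes=['rev'], atol=1e-2, rtol=1e-2)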

shoyer commented 4 years ago

I'm going to tentatively mark this as "fixed" by #4130.