google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Auxiliary arguments inside lax.root's solve #1448

Closed shoyer closed 4 years ago

shoyer commented 4 years ago

Consider the signature of lax.root, for simplicity omitting the tangent_solve argument:

def root(f, initial_guess, solve):
  """Differentiably solve for the roots of a function.

  This is a low-level routine, mostly intended for internal use in JAX.
  Gradients of root() are defined with respect to closed-over variables from
  the provided function f.

  Args:
    f: function for which to find a root. Should accept a single argument,
      return a tree of arrays with the same structure as its input.
    initial_guess: initial guess for a zero of f.
    solve: function to solve for the roots of f. Should take two positional
      arguments, f and initial_guess, and return a solution with the same
      structure as initial_guess such that f(solution) = 0. In other words,
      the following is assumed to be true (but not checked)::

        solution = solve(f, initial_guess)
        error = f(solution)
        assert all(error == 0)

  Returns:
    The result of calling solve(f, initial_guess) with gradients defined via
    implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
  """

The essence of root is that it calculates solve(f, initial_guess) with a custom JVP rule defined via implicit differentiation of the function f.
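To make the implicit-differentiation idea concrete, here is a minimal pure-Python sketch for a scalar problem. The residual f, the bisection solver, and all helper names are illustrative assumptions, not part of lax.root; the only ingredient taken as given is the implicit function theorem, dy/dtheta = -(df/dtheta) / (df/dy).

```python
def f(y, theta):
    # Illustrative residual: its positive root is y = sqrt(theta).
    return y * y - theta

def bisect_root(theta, low=0.0, high=100.0, iters=200):
    # Black-box solver: nothing in here is ever differentiated.
    for _ in range(iters):
        mid = 0.5 * (low + high)
        if f(mid, theta) > 0:
            high = mid
        else:
            low = mid
    return 0.5 * (low + high)

def implicit_derivative(y, theta):
    # f(y(theta), theta) = 0  =>  df/dy * dy/dtheta + df/dtheta = 0.
    df_dy = 2.0 * y
    df_dtheta = -1.0
    return -df_dtheta / df_dy

theta = 9.0
y = bisect_root(theta)
dy_dtheta = implicit_derivative(y, theta)
print(y, dy_dtheta)
```

The derivative only needs f and the converged root, which is why the solver itself can stay opaque.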

An implicit requirement of lax.root is that solve(f, initial_guess) must be a pure function, without any closed-over variables. This would suffice for the generic scipy.optimize.root, but unfortunately, we really do want data-dependent solvers in many interesting cases. How can we use these with lax.root, or with the similarly designed lax.linear_solve?

Unfortunately, as soon as we turn solve into a closure, it breaks when we use tracers, e.g., for jit:

from functools import partial
import jax
from jax import lax
import jax.numpy as np

simple_root = partial(lax.root, tangent_solve=None)

def linear_solve(a, b):
  f = lambda y: np.dot(a, y) - b
  x0 = np.linalg.solve(a, b)
  return simple_root(f, np.zeros_like(b), lambda *args: x0)

a = np.eye(2)
b = np.ones((2,))
linear_solve(a, b)  # DeviceArray([1., 1.], dtype=float32)
jax.jit(linear_solve)(a, b)  # TypeError: No constant handler for type: <class 'jax.interpreters.partial_eval.JaxprTracer'>

The problem is that our closed over tracer leaks into the inside of the root primitive.

One answer is to compute the solution outside of solve, and simply pass it in as the initial_guess, e.g.,

def linear_solve(a, b):
  f = lambda y: np.dot(a, y) - b
  x0 = np.linalg.solve(a, b)
  return simple_root(f, x0, lambda f, x0: x0)

This version works! In fact, it suggests that perhaps we should have a different, simpler interface for lax.root:

def root(f, x):
  """Define gradients via implicit differentiation.

  Args:
    f: function for which to find a root. Should accept a single argument,
      return a tree of arrays with the same structure as its input.
    x: zero of f, i.e., all(f(x) == 0).

  Returns:
    x, but with gradients defined via implicit differentiation of f.
  """

Unfortunately, this introduces another problem: if we calculate x with non-differentiable primitives, we won’t even get to calling root. JAX will raise an error about undefined JVP rules, e.g., for while_loop. This is why we added the solve argument to root in the first place.

I see several possible ways to resolve this, none of which are entirely satisfactory:

  1. Add extra optional arguments to lax.root for explicitly passing auxiliary arguments into solve. These would get passed into the root primitive directly, allowing the closure problem to be side-stepped, but it’s still surprising that you can’t use a closure in solve.
  2. Change the signature of root to root(f, x), removing the solve argument. Encourage liberal use of lax.stop_gradient to avoid attempting to compute uncomputable or expensive-to-compute JVP terms. Maybe we can make this easier by adding either a higher-order function or a context manager to stop gradients, e.g., lax.stop_grad(lax.while_loop)(...) or with lax.disable_gradients(): ....
  3. Somehow evaluate solve into a JAXpr inside root (the function, not the primitive) without evaluating its JVP. This looks sort of like the context manager solution, except contained inside root.

shoyer commented 4 years ago

linear_solve(matvec, b, solve) is another interesting case. There are common situations where it makes sense to do some pre-computation for creating the function solve that is used for computing the linear solve, e.g., creating a preconditioner or factorization.

Because we often use the same solve multiple times, e.g., in the forward and backwards pass for symmetric matrices, it doesn't make sense to supply an explicit solution x to linear_solve. We'll still want to keep the inverse function solve.

This suggests to me that we'll need to evaluate solve to a JAXpr in linear_solve before calling the primitive. We can use stop_gradient() on the inputs into solve, but I suspect users will need to make their own use of stop_gradient() if they use closed-over variables, unless we make use of a solution like (3) above.
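As a small illustration of what stop_gradient() does to a closed-over variable, here is a sketch using the public jax.lax.stop_gradient. The function names are made up for the example; the point is only that the closed-over path is treated as a constant during differentiation.

```python
import jax
from jax import lax

def with_closure(x):
    # `helper` closes over x; stop_gradient blocks gradient flow
    # through the closed-over path only.
    helper = lambda y: y * lax.stop_gradient(x)
    return helper(x)

def without_stop(x):
    helper = lambda y: y * x
    return helper(x)

grad_stopped = jax.grad(with_closure)(3.0)   # derivative of y * const at y = 3
grad_full = jax.grad(without_stop)(3.0)      # derivative of x ** 2 at x = 3
```

With the closed-over x stopped, only the explicit argument contributes to the gradient.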

shoyer commented 4 years ago

I guess the natural question with the context manager is what happens when you have a nested call that requires tracing gradients, e.g.,

with lax.disable_gradients():
  y = jax.grad(f)(g(x))

Ideally this should only remove one "layer" of gradients, i.e., gradients should be removed from g(x) and jax.grad(f) but not inside the evaluation of f.

gehring commented 4 years ago

@shoyer

Because we often use the same solve multiple times, e.g., in the forward and backwards pass for symmetric matrices, it doesn't make sense to supply an explicit solution x to linear_solve. We'll still want to keep the inverse function solve.

I'm not sure I understand what you are saying. Would you mind expanding your linear_solve example (maybe with a few lines of pseudo-code)? Some collaborators and I have been working on some instances of implicit differentiation using JAX. We've been pondering how best to design various APIs to stay as general as possible, so this discussion seems very relevant to us!

shoyer commented 4 years ago

Reusing a Cholesky factorization for the backwards/forwards passes inside scipy.linalg.solve is probably a good example: https://nbviewer.jupyter.org/gist/shoyer/7fa06b95839de596cf3a649ac7217919

Here's what the core logic looks like:

def symmetric_solve(a, b):
  cholesky, lower = jax.scipy.linalg.cho_factor(jax.lax.stop_gradient(a))
  result, = symmetric_solve_p.bind(a, b, cholesky, lower=lower)
  return result

def _symmetric_solve_abstract_eval(a, b, cholesky, lower):
  return (b,)

def _symmetric_solve_impl(a, b, cholesky, lower):
  return (jax.scipy.linalg.cho_solve((cholesky, lower), b),)

def _symmetric_solve_jvp(primals, tangents, lower):
  # A x - b = 0
  # ∂A x + A ∂x - ∂b = 0
  # ∂x = A^{-1} (∂b - ∂A x)
  a, b, cholesky = primals
  tangent_a, tangent_b, _ = tangents
  x, = symmetric_solve_p.bind(a, b, cholesky, lower=lower)
  if tangent_a is ad_util.zero:
    rhs = tangent_b
  elif tangent_b is ad_util.zero:
    rhs = -np.dot(tangent_a, x)
  else:
    rhs = tangent_b - np.dot(tangent_a, x)
  tangent_x, = symmetric_solve_p.bind(a, rhs, cholesky, lower=lower)
  return (x,), (tangent_x,)

def _symmetric_solve_transpose_rule(cotangents, a, b, cholesky, lower):
  assert b is ad.undefined_primal
  cotangent_x, = cotangents
  # Note: does not seem to be quite right yet -- see failing check_grads below 
  cotangent_b, = symmetric_solve_p.bind(a, cotangent_x, cholesky, lower=lower)
  return (None, cotangent_b, None)

symmetric_solve_p = core.Primitive('symmetric_solve')
symmetric_solve_p.multiple_results = True

symmetric_solve_p.def_impl(_symmetric_solve_impl)

symmetric_solve_p.def_abstract_eval(_symmetric_solve_abstract_eval)

ad.primitive_jvps[symmetric_solve_p] = _symmetric_solve_jvp

xla.translations[symmetric_solve_p] = xla.lower_fun(_symmetric_solve_impl, instantiate=True)

ad.primitive_transposes[symmetric_solve_p] = _symmetric_solve_transpose_rule

# TODO: batching rule

This turns out to produce a VJP rule that is significantly faster than what JAX currently produces for scipy.linalg.solve with sym_pos=True, e.g., 1.16 ms vs 4.58 ms on a CPU in my microbenchmark.

Note that most of this logic in the JVP/VJP rules is generic, independent of the details of how the linear solve is performed. I'd like to write a linear_solve primitive that JAX can use internally to implement these rules for scipy.linalg.solve, as well as more general forms of linear solve that may not use explicit matrices.
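The tangent formula in the JVP rule above, ∂x = A^{-1} (∂b - ∂A x), can be checked numerically against JAX's built-in rule for a dense solve. This is only a sanity-check sketch: the matrices and tangents are arbitrary made-up values, and it uses jnp.linalg.solve rather than the symmetric_solve primitive from the snippet.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
tangent_a = jnp.array([[0.1, 0.0], [0.2, -0.3]])
tangent_b = jnp.array([0.5, -1.0])

# Reference: JAX's own forward-mode rule for jnp.linalg.solve.
x, tangent_x = jax.jvp(jnp.linalg.solve, (a, b), (tangent_a, tangent_b))

# Generic formula: differentiate A x - b = 0 and solve for the tangent of x.
tangent_x_formula = jnp.linalg.solve(a, tangent_b - tangent_a @ x)
```

Nothing in the formula depends on how the solve is performed, which is the sense in which the rule is generic.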

shoyer commented 4 years ago

To be more specific, my hope is that we could equivalently write symmetric_solve with all the complexity hidden inside lax.linear_solve:

def symmetric_solve(a, b):
  factors = scipy.linalg.cho_factor(lax.stop_gradient(a))
  def solve(matvec, b):
    return scipy.linalg.cho_solve(factors, b)
  return lax.linear_solve(partial(np.dot, a), b, solve, symmetric=True)

jekbradbury commented 4 years ago

It might be worth trying to write root as a full-fledged higher-order primitive that binds subjaxprs for the function and solver. That way (at least in principle) you could support closing over tracers in either of those functions (but you’d also have the same implementation complexity as the other higher-order control primitives).

gehring commented 4 years ago

@shoyer Thank you for the detailed answer! This made everything much clearer for me.

The "full-fledged higher-order primitive" solution @jekbradbury proposed feels like it would be the ultimate solution, but it's not clear to me whether it is necessary or whether the added complexity would outweigh the benefits. I like the simplicity of the API in @shoyer's solution 2, if it can be done without requiring the user to manually specify stop_gradient or use a similar context manager.

Given how simple it is to use custom_gradient, defjvp, etc., a very flexible low-level/primitive API to handle this might not provide enough benefits to outweigh the extra complexity, docs, maintenance, possible confusion. The simplified root idea is lightweight and specialized, and I wouldn't expect it to cause much confusion. You could implement simple_root as something like lax.defroot, where lax.defroot(f, x) is the identity function returning x but defines the derivative of the output through implicit differentiation under the assumption that f(x) = 0. If a high-level root solving method is still desired, it could be placed in an appropriate sub-package, e.g., jax.optimize.root, and wrap the relevant low-level functions. Though, these are just my initial impressions, I could see arguments for preferring other solutions.

shoyer commented 4 years ago

I am also leaning towards replacing root with a simpler primitive which only takes the arguments (f, x). Given that there appears to be no way to avoid evaluating solve if we put it in the arguments, I don't see much advantage in including it.

Maybe lax.implicit_gradient() would be a good name for it? We'll still need a tangent_solve argument (though we will evaluate its jaxpr so it's safe to use a closure), so the signature would be lax.implicit_gradient(f, x, tangent_solve).

I am indeed still interested in jax.scipy.optimize.root, but that's definitely a more complex beast.

shoyer commented 4 years ago

My implementation of linear_solve can now handle the closure over a factorization case: https://github.com/google/jax/pull/1402. My symmetric_solve snippet from above is one of the test cases.
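For readers on a later JAX release, the closure-over-a-factorization pattern can be sketched with jax.lax.custom_linear_solve, which is the public function with this (matvec, b, solve, symmetric=...) shape. Whether it matches the lax.linear_solve of this thread in every detail is an assumption; the test matrix and names below are illustrative.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl
from jax import lax

def symmetric_solve(a, b):
    # Factor once; the closed-over factorization is reused by every solve,
    # including the one needed for the backward pass.
    factors = jsl.cho_factor(lax.stop_gradient(a))

    def solve(matvec, rhs):
        # `matvec` is unused here: the Cholesky factors do the work.
        return jsl.cho_solve(factors, rhs)

    return lax.custom_linear_solve(
        lambda v: jnp.dot(a, v), b, solve, symmetric=True)

a = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive definite
b = jnp.array([1.0, 2.0])

x = symmetric_solve(a, b)
grad_b = jax.grad(lambda rhs: symmetric_solve(a, rhs).sum())(b)
```

Passing symmetric=True lets the same solve be reused for the transposed system, so no separate transpose_solve is supplied.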

gehring commented 4 years ago

Maybe lax.implicit_gradient() would be a good name for it? We'll still need a tangent_solve argument (though we will evaluate its jaxpr so it's safe to use a closure), so the signature would be lax.implicit_gradient(f, x, tangent_solve).

I'm 100% on board with lax.implicit_gradient. I think it is important to distinguish the naming of functions that actually solve things from those that define autodiff rules. The purpose of defining differentiation rules this way might be obvious for us, but, in my experience, concepts related to implicit differentiation are really non-obvious to a large portion of the comp sci community.

(this is probably a consequence of widely available backprop-based ml frameworks which have left many users under the false impression that backprop through everything is the only way to do differentiation)

shoyer commented 4 years ago

I think it is important to distinguish the naming of functions that actually solve things from those that define autodiff rules.

Yes, this does seem like a good practice. Maybe we should rename lax.linear_solve to something like lax.differentiable_linear_solve then? (This function does actually apply the linear solve, the key feature being that it makes any arbitrary linear solve differentiable.)

I am OK with a clunky name, this is a clunky low-level power feature :)

gehring commented 4 years ago

Maybe we should rename lax.linear_solve to something like lax.differentiable_linear_solve then?

I think that lax.differentiable_linear_solve with explicitly saying why this function exists in the docs is a perfectly good solution. In that case, I wouldn't expect any new users to use it without strongly suspecting that the purpose of this function is a bit more subtle and related to autodiff. I think this part of your reply would be a nice addition to the docs :)

[...] the key feature being that it makes any arbitrary linear solve differentiable.

shoyer commented 4 years ago

https://github.com/google/jax/pull/1550 includes implementations of two versions of the lax.root functionality, custom_root and define_implicit_gradient.

In both cases, arbitrary functions with closed-over variables are supported. They are basically interchangeable.

Here's what these APIs would look like in practice when solve is non-differentiable:

With custom_root:

def sqrt_cubed(x):
  f = lambda y: y ** 2 - x ** 3
  return lax.custom_root(f, 0.0, binary_search, tangent_solve)

With define_implicit_gradient:

def binary_search(func, ...):
  ...

def sqrt_cubed(x):
  f = lambda y: y ** 2 - x ** 3
  y = binary_search(lax.stop_gradient_fun(f))
  return lax.define_implicit_gradient(f, y, tangent_solve)

define_implicit_gradient has a simpler signature (3 vs 4 arguments) and lets you avoid needing to make up a value for initial_guess, but the downside is that you have to use a helper function like stop_gradient_fun to avoid trying to calculate the JVP of lax.while_loop.
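For reference, here is a filled-in sketch of the custom_root variant that runs against the public jax.lax.custom_root(f, initial_guess, solve, tangent_solve). The bodies of binary_search and scalar_solve are assumptions added for illustration (the snippets above leave them elided), and the search bounds are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import lax

def binary_search(func, x0, low=0.0, high=100.0):
    del x0  # the initial guess is not needed by bisection
    def cond(state):
        low, high = state
        midpoint = 0.5 * (low + high)
        # Stop once the interval can no longer be split in floating point.
        return (low < midpoint) & (midpoint < high)
    def body(state):
        low, high = state
        midpoint = 0.5 * (low + high)
        update_upper = func(midpoint) > 0
        low = jnp.where(update_upper, low, midpoint)
        high = jnp.where(update_upper, midpoint, high)
        return (low, high)
    solution, _ = lax.while_loop(cond, body, (low, high))
    return solution

def scalar_solve(g, y):
    # g is the linearization of f at the root; for a scalar, g(t) = c * t.
    return y / g(1.0)

def sqrt_cubed(x):
    f = lambda y: y ** 2 - x ** 3
    return lax.custom_root(f, 0.0, binary_search, scalar_solve)

value = sqrt_cubed(5.0)
gradient = jax.grad(sqrt_cubed)(5.0)
```

The gradient comes from implicit differentiation of f at the root, so the while_loop inside binary_search is never differentiated.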

For another data point, here's how each could be written in terms of the other:

def define_implicit_gradient(f, x, tangent_solve):
  solve = lambda f, x0: x0
  return lax.custom_root(f, x, solve, tangent_solve)

def custom_root(f, initial_guess, solve, tangent_solve):
  x = solve(lax.stop_gradient_fun(f), initial_guess)
  return lax.define_implicit_gradient(f, x, tangent_solve)

Any thoughts on which API feels more intuitive/usable? They are similar enough that I don't think we'll want both, though I guess we could if really desired.

gehring commented 4 years ago

Any thoughts on which API feels more intuitive/usable? They are similar enough that I don't think we'll want both, though I guess we could if really desired.

Both of those would be completely usable, imo. I agree that committing to only one of them is best.

I think I prefer define_implicit_gradient but the requirement for manually adding stop_gradient_fun might make it a bit clunky and prone to errors. Is there no way to have define_implicit_gradient (or maybe some higher level wrapper) handle this automatically?

To help me understand a bit better, is there a practical difference between these two?

def custom_root1(f, initial_guess, solve, tangent_solve):
  x = solve(lax.stop_gradient_fun(f), initial_guess)
  return lax.define_implicit_gradient(f, x, tangent_solve)

def custom_root2(f, initial_guess, solve, tangent_solve):
  x = lax.stop_gradient(solve(f, initial_guess))
  return lax.define_implicit_gradient(f, x, tangent_solve)

I think I'm a bit more confused than I thought as to what causes the requirement for stop_gradient_fun and why it is just an issue for define_implicit_gradient.

pierrelux commented 4 years ago

I think that this would be a good interface. I also agree with @gehring that the "stop_gradient_fun" shouldn't have to be something that one has to think about. It feels like a relic from the graph-based frameworks and doesn't fit nicely with the functional style of JAX. I also assume that by default, the root function would implement something like Christianson's two-phase algorithm https://doi.org/10.1080/10556789408805572

shoyer commented 4 years ago

I agree that stop_gradient_fun is really clunky.

It isn't needed for all solve functions, but I suspect it will be needed for most of them. In particular, right now it's needed for everything that makes use of lax.while_loop internally, because lax.while_loop does not have a JVP rule defined.

You immediately get an exception when you call binary_search(f) inside a function being differentiated, if the function f depends on any parameters. For example:

def sqrt_cubed(x):
  f = lambda y: y ** 2 - x ** 3
  y = binary_search(f)  # this is where the error happens
  return lax.define_implicit_gradient(f, y, tangent_solve)

This is why @gehring's custom_root2 example wouldn't work in JAX today. An exception would be raised on the first line under automatic differentiation, just from calling solve(f, initial_guess):

def custom_root2(f, initial_guess, solve, tangent_solve):
  x = lax.stop_gradient(solve(f, initial_guess))
  return lax.define_implicit_gradient(f, x, tangent_solve)

My understanding is that JAX's auto-diff works by evaluating JVP rules for every primitive for which the inputs have non-zero gradients. We don't actually need derivatives of the output of solve(), but JAX doesn't know that.

Possible ways to work around this:

  1. Write our solve functions only with primitives which have defined JVP rules, i.e., don't use lax.while_loop. This could work OK if we always use JIT, but in eager mode would still involve a lot of redundant JVP calculations that eventually get thrown away.
  2. Use something like stop_gradient_fun everywhere, which is just magic for rewriting f = lambda y: y ** 2 - x ** 3 into f = lambda y: y ** 2 - lax.stop_gradient(x) ** 3.
  3. Refactor JAX's JVP interpreter to no longer raise errors as soon as a non-differentiable function is encountered, but rather when derivatives of the non-differentiable function are actually needed -- which should not be the case if they are replaced by define_implicit_gradient or stop_gradient.

I think we could make (3) work, though it may have some complications for error reporting. It also has the downside that if the solve happens to be fully or partially differentiable, we're going to do some extra work computing JVPs unless computation for unused outputs is removed by jit.

gehring commented 4 years ago

Thanks for the explanation! I didn't realize that was how JAX handled auto-diff but given it is based on JVP ops, that makes sense.

I think we could make (3) work, though it may have some complications for error reporting. It also has the downside that if the solve happens to be fully or partially differentiable, we're going to do some extra work computing JVPs unless computation for unused outputs is removed by jit.

I think that would be the best solution. One possible compromise might be to only do this when tracing, and keeping stop_gradient_fun for the pure eager alternative. Though this might be worth some careful thought since it would be moving the behavior of tracing further away from the eager mode. However, I don't think a compromise like that would make define_implicit_gradient + stop_gradient_fun more likely to confuse someone.

Or... we could settle for custom_root. This is what we currently have in our package for handling fixed points (though based on specifying a VJP solver). Though, for some reason define_implicit_gradient feels "purer" and more in line with JAX's overall theme.

But more generally, issues of finite dev-hours aside, would it make sense (or even be possible) to change JAX's JVP handling to be completely lazy (maybe with shape inference and memory allocation still eager)? By that I mean do minimal JVP related work until the output of some JVP is needed by some non-JVP operation.

shoyer commented 4 years ago

I have a version of delayed errors for undefined JVPs working in https://github.com/google/jax/pull/1550 that removes the need for stop_gradient_fun in user code. We'll see if it's acceptable to the JAX gods!

One possible compromise might be to only do this when tracing, and keeping stop_gradient_fun for the pure eager alternative.

In JAX, "tracers" are used for all transformations, including both auto-diff and JIT, so your use of the word "tracing" to refer to JIT only (if I understand you correctly) confused me a little at first.

Currently JAX's auto-diff doesn't know anything about XLA/JIT, and I imagine we'd like to keep it that way.

But more generally, issues of finite dev-hours aside, would it make sense (or even be possible) to change JAX's JVP handling to be completely lazy (maybe with shape inference and memory allocation still eager)? By that I mean do minimal JVP related work until the output of some JVP is needed by some non-JVP operation

This is a great question! I'll leave it to the experts like @mattjj and @dougalm

shoyer commented 4 years ago

I also assume that by default, the root function would implement something like Christianson's two-phase algorithm https://doi.org/10.1080/10556789408805572

Actually, custom_root/define_implicit_gradient is even simpler than that. It's really just a restatement of the implicit function theorem in terms of auto-diff transformations. That's why you have to tell it how to compute tangent_solve.
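Spelled out, the implicit function theorem step looks like the following. The notation is an assumption for illustration: x^*(\theta) is the root, \theta stands for the closed-over parameters of f, and dots denote tangents. It mirrors the A x - b = 0 derivation in the symmetric_solve snippet earlier in the thread.

```latex
% Assumed notation: x^*(\theta) is the root, \theta the closed-over parameters.
\begin{aligned}
f\bigl(x^*(\theta), \theta\bigr) &= 0 \\
\partial_x f \,\dot{x}^* + \partial_\theta f \,\dot{\theta} &= 0 \\
\dot{x}^* &= -\left(\partial_x f\right)^{-1} \partial_\theta f \,\dot{\theta}
\end{aligned}
```

tangent_solve is the piece that applies \left(\partial_x f\right)^{-1}: it receives the linearization of f at the root and a right-hand side, and returns the solution of that linear system.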

We could eventually add a default implementation of tangent_solve, but I would also be happy to leave that for higher level functions like scipy.optimize.root. I guess the sensible default choices would be iterative Krylov methods like restarted GMRES or CG, but these really require preconditioning to work well.

You could write Christianson's two-phase algorithm as a specific implementation of tangent_solve, but it requires specialized knowledge about the form of f in define_implicit_gradient. Specifically, you need access to a function g such that f(x) = g(x) - x = 0. So it probably makes sense to write another function implementing this specific gradient strategy, e.g., scipy.optimize.fixed_point.
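As a rough illustration of why the fixed-point form f(x) = g(x) - x helps, here is a pure-Python scalar sketch loosely in the spirit of a two-phase scheme: first converge the fixed point, then iterate a derivative recursion at the converged point. It is written in forward (tangent) form and is not a transcription of the algorithm in the cited paper; the map g and all names are made up for the example.

```python
import math

def g(x, theta):
    # Illustrative contraction; its fixed point solves x = theta * cos(x).
    return theta * math.cos(x)

def fixed_point_with_tangent(theta, iters=200):
    # Phase 1: iterate x <- g(x, theta) until (approximately) converged.
    x = 0.0
    for _ in range(iters):
        x = g(x, theta)
    # Phase 2: hold x fixed and iterate the tangent recursion
    #   dx <- (dg/dx) * dx + dg/dtheta,
    # which converges because |dg/dx| < 1 at the fixed point.
    dg_dx = -theta * math.sin(x)
    dg_dtheta = math.cos(x)
    dx = 0.0
    for _ in range(iters):
        dx = dg_dx * dx + dg_dtheta
    return x, dx

x_star, dx_dtheta = fixed_point_with_tangent(1.0)
# Closed form from x = g(x, theta): dx/dtheta = g_theta / (1 - g_x), theta = 1.
closed_form = math.cos(x_star) / (1.0 + math.sin(x_star))
```

The second loop needs the partial derivatives of g, which is exactly the specialized knowledge a generic tangent_solve does not have.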

gehring commented 4 years ago

In JAX, "tracers" are used for all transformations, including both auto-diff and JIT, so your use of the word "tracing" to refer to JIT only (if I understand you correctly) confused me a little at first.

Yes, that was my understanding of JAX's auto-diff but I had assumed I must have misunderstood something. My confusion came from trying to understand why these unused JVP operations could not be pruned when tracing regardless of jit. My other question sort of asked the same thing so I'm happy to wait and see what the experts have to say!

shoyer commented 4 years ago

But more generally, issues of finite dev-hours aside, would it make sense (or even be possible) to change JAX's JVP handling to be completely lazy (maybe with shape inference and memory allocation still eager)? By that I mean do minimal JVP related work until the output of some JVP is needed by some non-JVP operation.

I've been learning more about how JAX's transformations work.

Transformations like forward-mode automatic differentiation, batching and JIT compilation work by overloading normal Python evaluation, with arrays replaced by symbolic tracers. For example, in forward-mode auto-diff, every variable is replaced by a tracer that keeps track of both original variable (the "primal") and its derivative (the "tangent"). There's no graph, just normal Python code.
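A toy version of this idea fits in a few lines of plain Python. This is a sketch of operator overloading with (primal, tangent) pairs, not JAX's actual tracer implementation; the class and function names are invented for illustration.

```python
from dataclasses import dataclass

@dataclass
class Dual:
    primal: float   # the original value
    tangent: float  # its derivative along the input direction

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other)
        # Product rule, applied as the multiplication is evaluated.
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __rmul__ = __mul__

def _lift(value):
    # Plain numbers are constants, so their tangent is zero.
    return value if isinstance(value, Dual) else Dual(float(value), 0.0)

def jvp(fun, primal, tangent):
    # Run ordinary Python code on a Dual input; no graph is built.
    out = fun(Dual(primal, tangent))
    return out.primal, out.tangent

def cubic(x):
    return x * x * x + 2 * x

value, derivative = jvp(cubic, 3.0, 1.0)
```

Because cubic runs as ordinary Python, an unsupported operation would raise at the exact line that used it, which is the debugging property described here.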

This is really nice when it comes to debugging, because it means you get really sensible error messages and tracebacks from Python, pointing back exactly to the line where things went wrong. For example, if you try to differentiate a function using while_loop, you may get a long traceback, but somewhere in there it will point to a line of code that you wrote. You can even drop into a debugger.

In contrast, backwards mode auto-diff requires evaluating an abstract computation graph. If something breaks, the Python traceback will not be very interpretable, because it will point to an evaluation deep inside JAX's auto-diff machinery not your original Python function.

One reason why JAX is such a pleasure to use is that such abstract graph evaluation is done only to the bare minimum extent possible. Backwards mode auto-diff is decomposed into a series of simpler transformations, namely forward mode auto-diff followed by transposition. Only the latter part requires evaluating a graph rather than normal Python code.

Making JVP evaluation lazy would negate this advantage. Instead of evaluating normal Python code for JVP rules, we'd be evaluating a stored computation graph.

We run into a similar issue if we try to make errors from missing JVP rules lazy. Now legitimately missing JVP rules result in error messages that no longer point back to user code, so it's no longer obvious where things went wrong.

For these reasons, I think it's a non-starter to change JAX's JVP evaluation to be lazy, even for error reporting. It would be nice for these niche use cases, but would make debugging harder for everything else.

gehring commented 4 years ago

@shoyer Thanks once more for the detailed explanation! I suspected that the answer would be something like that. Being so accustomed to the VJP/backprop approach to auto-diff, it is easy to forget the drawbacks we're trying to avoid!

Continuing the define_implicit_gradient vs. custom_root discussion, I think the need for stop_gradient_fun puts both options on equal footing. In use cases where JVPs are what you want, I wouldn't expect stop_gradient_fun to feel too awkward. I think this would be confusing only for a user who is only interested in VJPs.

Now, I think I would put my vote on define_implicit_gradient if stop_gradient_fun can be avoided when only using VJPs. Would it make sense to have lazy errors only in the VJP case (maybe with some context manager around linearize)? If this could be made possible without making tracebacks worse as you described, then I think it would make define_implicit_gradient the best option:

def sqrt_cubed_for_vjp(x):
  f = lambda y: y ** 2 - x ** 3
  y = binary_search(f)
  y = lax.stop_gradient(y)  # sufficient for vjp?
  return lax.define_implicit_gradient(f, y, tangent_solve)

def sqrt_cubed_for_jvp(x):
  f = lambda y: y ** 2 - x ** 3
  y = binary_search(lax.stop_gradient_fun(f))
  return lax.define_implicit_gradient(f, y, tangent_solve)

jax.grad(sqrt_cubed_for_vjp)(2.)  # nice if this works
jax.jvp(sqrt_cubed_for_vjp, (2.,), (1.,))  # fine if this fails

jax.grad(sqrt_cubed_for_jvp)(2.)  # works as expected
jax.jvp(sqrt_cubed_for_jvp, (2.,), (1.,))  # works as expected

If this can't be done without heavily modifying the AD backend or without worsening the debugging experience, then I think custom_root is the best compromise. WDYT?

shoyer commented 4 years ago

Would it make sense to have lazy errors only in the VJP case (maybe with some context manager around linearize)?

We could make errors lazy when calculating linearize inside VJP only, but it would still make error reporting worse for VJPs. The "forward pass" of linearize still uses normal Python evaluation, while building up an abstract computation graph (which gets traversed inside backward_pass).

We can see this by defining a dummy primitive without a JVP rule:

import jax
import traceback

def identity(x):
  return identity_p.bind(x)

identity_p = jax.core.Primitive('identity')
identity_p.def_impl(lambda x: x)

def uses_identity(x):
  return identity(x)

try:
  jax.grad(uses_identity)(1.0)
except Exception as e:
  assert 'uses_identity' in traceback.format_exc()
  print(e)  # Forward-mode differentiation rule for 'identity' not implemented

This would also be the case for primitives like while_loop.

If this can't be done without heavily modifying the AD backend or without worsening the debugging experience, then I think custom_root is the best compromise. WDYT?

I agree. custom_root is a little clunky, but it is harder to misuse.

gehring commented 4 years ago

I agree. custom_root is a little clunky, but it is harder to misuse.

Sounds like we are in agreement on all points. I'm happy with the discussion and I can't think of any new points to bring up. If we can't reasonably get define_implicit_gradient to work with VJPs without needing stop_gradient_fun (which, as I understand it, is not possible/worth it), then I'm happy with preferring custom_root.

Looking forward to seeing this PR merged! This is quality work :)