jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Conditional derivatives should raise a warning or exception #676

Closed proteneer closed 2 months ago

proteneer commented 5 years ago

This was really tricky to track down and reproduce: the bug manifests when you need to differentiate through a conditional itself (as opposed to through its branches). I may be using different terminology from you folks, so I think some code is most illustrative:

```python
import numpy as onp
import jax
import jax.numpy as np

def foo(x):
    # x only affects which branch is taken; each branch returns a constant,
    # so the derivative of either branch with respect to x is zero.
    if onp.random.rand() < np.cos(x):
        return 2
    else:
        return 3

def expectation(val):
    total = 0
    counts = 50000
    for _ in range(counts):
        total += foo(val)
    return total/counts

# we can also compute the expectation
# analytically as:

#   P(true)*2 + P(false)*3
# = cos(x)*2 + (1-cos(x))*3

x = expectation(0.5)
print(x)
# for large counts, this approaches
# 2.12241743811

# the gradient of the expectation should be:
# -sin(x)*2 + sin(x)*3 = sin(x)
# 0.4794255386

grad_expectation = jax.jacfwd(expectation, argnums=(0,))
# always returns zero because foo returns constants
print("autodiff", grad_expectation(0.5))
```
proteneer commented 5 years ago

Also: if you guys have an actual solution to do this in a general way, I'd be happy with not raising an exception or warning either :)

mattjj commented 5 years ago

Thanks for the brain-teaser!

I think there might be a math issue here actually; in particular, I don't think this function meets the usual sufficient conditions for Leibniz's integral rule to apply, but the consequent of that rule is essentially equivalent to the conclusion you want to draw.

Let's rewrite the function to more explicitly depend on the random element:

```python
import jax.numpy as np

def f(x, w):
  if w < np.cos(x):
    return 2
  else:
    return 3
```

Now, taking it as a function over $\mathbb{R} \times (0, 1)$, it isn't continuous, and the partial derivative with respect to $x$ doesn't exist everywhere. In particular, that violates condition 2 (and 3) of the measure-theoretic statement of the rule, because for every $w$ there is some $x$ (namely one with $\cos(x) = w$) at which the derivative fails to exist. (In some sense, even if we said it exists but takes on an infinitely large value to make the function value jump from 2 to 3, the third condition would fail, because it requires an integrable bound on the magnitude of the derivative that is independent of $x$.)

The gradient of the expectation is really the left-hand side of the Leibniz rule (differentiation outside the integral), but running a Monte Carlo estimate using this program is estimating the right-hand side of the rule (integration on the outside).
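For concreteness, here are the two sides written out for this $f$ with $w$ uniform on $(0, 1)$ and any $x$ with $0 < \cos(x) < 1$ (e.g. $x = 0.5$). This is a worked restatement, not part of the original comment; the second line uses the fact that $\partial f / \partial x = 0$ for almost every $w$:

```latex
\frac{d}{dx} \int_0^1 f(x, w)\, dw
  = \frac{d}{dx}\Big[ 2\cos(x) + 3\big(1 - \cos(x)\big) \Big]
  = \sin(x)
  \qquad \text{(derivative of the integral)}

\int_0^1 \frac{\partial f(x, w)}{\partial x}\, dw
  = \int_0^1 0 \, dw
  = 0
  \qquad \text{(integral of the derivative)}
```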

What do you think? Do you buy that this is math's fault and not computer science's fault?

mattjj commented 5 years ago

Ah but maybe you're asking "can an autodiff system detect this kind of issue?" I'm not sure... at the very least I'd guess we would need to write the function in a friendly way, i.e. we'd need to write it in the f(x, w) form (where w may just be the PRNGKey in general), then perhaps look at the function that results when we hold w sufficiently abstract (in some neighborhood of the value of x we care about). If that thing is free of pathologies then differentiating under the integral sign would be safe... That sounds like a fun research project :)
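A minimal sketch of the $f(x, w)$ form, assuming $w$ is drawn with `jax.random.uniform` from an explicit PRNG key and the branch is written with `jnp.where` instead of a Python `if` (both are illustrative choices, not code from the thread). Holding the key fixed makes $x \mapsto f(x, \text{key})$ an ordinary deterministic function, and differentiating it gives zero for that draw:

```python
import jax
import jax.numpy as jnp

def f(x, key):
    # The random element w is an explicit function of the key, so for a
    # fixed key, x -> f(x, key) is deterministic.
    w = jax.random.uniform(key)
    # jnp.where replaces the Python if/else; both branches are constants.
    return jnp.where(w < jnp.cos(x), 2.0, 3.0)

key = jax.random.PRNGKey(0)
df_dx = jax.grad(f)(0.5, key)  # derivative for this one fixed draw of w
print(df_dx)  # 0.0: f is piecewise constant in x away from w == cos(x)
```

Checking whether differentiating under the integral is safe would then mean analyzing this function with the key (or $w$) held abstract, which is the open question raised above.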

proteneer commented 5 years ago
  1. I agree that this is a particularly pathological application of the Leibniz integral rule, with the various mathematical pathologies associated with it. Unfortunately, in our field (and many others), the analytic integral is completely intractable, so we need to resort to Monte Carlo methods (i.e. the RHS).

  2. The current state is that the returned gradient is both incorrect and nonsensical, and I view this as a silent-error type of bug: it took quite a while to debug. There's nothing useful one can do with the returned gradient as-is, so I'd rather it throw an exception. You understand JAX's internals better than I do, so you'd know whether it can even detect this kind of bug (other systems like TensorFlow return None rather than 0 for a gradient when they cannot trace a path to the input variable).

  3. I'd be very much interested in finding ways to differentiate under the integral sign, as that would essentially allow one to optimize through a lot of really cool Monte Carlo models, even beyond our field. The analytic derivative of the expectation for a uniform sampler (used by the vast majority of Monte Carlo models) would actually be:


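```python
import numpy as onp
import jax.numpy as np

def foo(x):
    if onp.random.rand() < np.cos(x):
        return 2
    else:
        return 3

def analytic_foo_grad(x):
    # d/dx E[foo(x)] = d/dx [2*cos(x) + 3*(1 - cos(x))] = sin(x)
    return np.sin(x)
```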

If it were even remotely possible to automatically convert foo to analytic_foo_grad, we'd be in business (at least for the uniform case).

  4. I agree this is borderline a research project, but a pretty cool one nevertheless.

mattjj commented 5 years ago

I mostly disagree with bullet 2; JAX is returning a sensible derivative. That is, $\partial [ x \mapsto f(x, w) ] = 0$ for almost all $x$ and $w$ (i.e. everywhere except when $w = \cos(x)$), where $f(x, w) = 2 I[w < \cos(x)] + 3 I[w \geq \cos(x)]$. If you disagree, what do you think the derivative of that function is? (Notice there's no integration here.)

When $w = \cos(x)$ no Fréchet derivative exists, and it would be reasonable to raise an error there, but (1) that doesn't affect the integration (we can do surgery on a set of measure zero), so I'm assuming that's not the error you're talking about, and (2) JAX gives you one of the directional derivatives (the one from the right).

In other words, nothing's going wrong with the Monte Carlo estimate of the integral of the derivative; the integral of the derivative is zero for all x, as the Monte Carlo approximation reports (with zero variance, in fact!). The problem is that what you actually want is the derivative of the integral, but the derivative of the integral is not equal to the integral of the derivative in general.
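The contrast can be reproduced numerically. This is a sketch, assuming the integrand is rewritten with `jnp.where` so it can be vectorized with `jax.vmap`, and using the closed-form expectation $2\cos(x) + 3(1 - \cos(x))$ from the original post for the derivative of the integral:

```python
import jax
import jax.numpy as jnp

def f(x, w):
    # Same integrand as in the thread, written with jnp.where instead of if/else.
    return jnp.where(w < jnp.cos(x), 2.0, 3.0)

x = 0.5
ws = jax.random.uniform(jax.random.PRNGKey(0), (50_000,))

# Integral of the derivative: average the per-sample derivatives d/dx f(x, w).
per_sample = jax.vmap(jax.grad(f), in_axes=(None, 0))(x, ws)
integral_of_derivative = per_sample.mean()

# Derivative of the integral: differentiate the closed-form expectation.
def expectation(x):
    return 2.0 * jnp.cos(x) + 3.0 * (1.0 - jnp.cos(x))

derivative_of_integral = jax.grad(expectation)(x)

print(integral_of_derivative)  # 0.0, and every per-sample value is 0.0 (zero variance)
print(derivative_of_integral)  # approximately sin(0.5) = 0.4794
```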

If you want to form a valid Monte Carlo approximation to the derivative of the integral, one way to do it would be to find an integral representation that satisfies the conditions of the Leibniz rule. In the recent ML literature we tend to call those "reparameterization gradients" and there's a bit of a cottage industry in developing them for common expectations. I'm sure there's a much longer history, with different terminology, in physics.
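To make the term concrete, here is the standard textbook reparameterization example. It is a Gaussian, not the discontinuous integrand from this thread, and the helper name `reparam_grad_estimate` is hypothetical. Writing $z \sim \mathcal{N}(\mu, \sigma^2)$ as $z = \mu + \sigma \epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$ moves the dependence on $\mu$ into a smooth integrand, so averaging per-sample derivatives estimates $\frac{d}{d\mu} \mathbb{E}[z^2] = 2\mu$:

```python
import jax
import jax.numpy as jnp

def reparam_grad_estimate(mu, sigma, key, n=100_000):
    # Hypothetical helper. Reparameterize z ~ N(mu, sigma^2) as z = mu + sigma * eps,
    # eps ~ N(0, 1): the sampling distribution of eps no longer depends on mu.
    eps = jax.random.normal(key, (n,))

    def objective(mu):
        z = mu + sigma * eps
        return jnp.mean(z ** 2)  # Monte Carlo estimate of E[z^2] = mu^2 + sigma^2

    # The integrand is smooth in mu with an integrable derivative bound,
    # so differentiating the Monte Carlo average is a valid estimator here.
    return jax.grad(objective)(mu)

g = reparam_grad_estimate(1.5, 1.0, jax.random.PRNGKey(0))
print(g)  # close to the exact value 2 * mu = 3.0, up to Monte Carlo noise
```

The discontinuous $f(x, w)$ above has no such smooth rewrite in terms of the same $w$, which is exactly why the naive estimator fails for it.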

mattjj commented 5 years ago

To focus the discussion, can we come up with any concrete action items for JAX here? If not it might be that trying to reason about integration is out of scope.

mattjj commented 5 years ago

Here's an attempt at a sharper example:

Take the function $f : (0, 1)^2 \to \mathbb{R}$ defined by $f(x, y) = I[y < x]$, which takes value 1 when $y < x$ and 0 otherwise, and consider also the function $F : (0, 1) \to \mathbb{R}$ where $F(x) = \int_0^1 f(x, y) \, dy = x$. I think these are true:

  1. $d/dx F(x) = 1$,
  2. $f_x(x, y) = 0$ almost everywhere as a function of $y$,
  3. if $g : (0, 1) \to \mathbb{R}$ is zero almost everywhere then $\int_0^1 g(y) \, dy = 0$.

The second item is why I think JAX is returning a sensible derivative, and the first compared to the third is a model for how the estimator is breaking down.
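A small numerical sketch of this example, assuming a Monte Carlo estimate `F_hat` of $F$ built from uniform samples (the names and the step size `h` are illustrative choices, not from the thread). Autodiff of the estimator differentiates each indicator pointwise and returns 0, matching items 2 and 3, while a central difference whose step straddles many of the jumps recovers item 1, $F'(x) = 1$:

```python
import jax
import jax.numpy as jnp

x = 0.3
ys = jax.random.uniform(jax.random.PRNGKey(0), (100_000,))

def F_hat(x):
    # Monte Carlo estimate of F(x) = integral_0^1 I[y < x] dy = x.
    return jnp.mean(jnp.where(ys < x, 1.0, 0.0))

# Pointwise autodiff: each indicator has f_x = 0 away from y == x.
ad_slope = jax.grad(F_hat)(x)

# Central difference with a step wide enough to contain many jumps
# (about 2 percent of the samples fall inside [x - h, x + h)).
h = 0.01
fd_slope = (F_hat(x + h) - F_hat(x - h)) / (2 * h)

print(ad_slope)  # 0.0
print(fd_slope)  # roughly 1.0, up to Monte Carlo noise
```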

mattjj commented 5 years ago

I keep coming back to this issue :) Really it is a great brain teaser, so thanks for bringing it up!

Maybe it would be a cool programming languages / machine learning problem to automatically generate reparameterization estimators (or report failure) for a probabilistic program. It would have to handle cases like the ones we're discussing in this thread. Now that I say that I suspect there is probably already a literature on doing this in the probabilistic programming community and I'm just not familiar with it (or not remembering it because I've spent too long away from machine learning!). @fehiepsi do you know of work on automatically forming reparameterization-style Monte Carlo estimators for derivatives of integral representations?

(I did some work on another kind of automatic approximate integration recently, and someday I'm going to revive that in JAX!)

proteneer commented 5 years ago

Thank you for taking the time to write such a detailed and mathematically rigorous response. I completely agree with your analysis regarding the non-commutativity of the expectation and derivative operators when the measure-theoretic assumptions fail. Now I'm going to put on my physicist hat, and you're going to hate me for it :)

1) As you've noted, the Fréchet derivative does not exist at $w = \cos(x)$. But observe that if we were to "simply" differentiate $f(x, w)$ w.r.t. $x$ and apply the chain rule through the Heaviside function (assuming distributional calculus, which introduces Dirac deltas), you'd arrive at:

$\partial f(x, w)/\partial x = -2 \delta(w - \cos(x)) \sin(x) + 3 \delta(w - \cos(x)) \sin(x) = \delta(w - \cos(x)) \sin(x)$

2) Numerical issues aside, and if you believe in black magic, the above is actually identical to the derivative of the expectation (the left-hand side) once it is integrated over $w$, because the delta picks out $w = \cos(x)$. I can directly use the above equation to get the reparameterized gradient for free, simply by substituting $w = \cos(x)$ instead of a random number.

Also note that the cos(x) here is completely arbitrary; it can be any function g(x). I think the above trick works as long as we're sampling from a uniform distribution. I've used this for a bunch of other things in the past. So basically, if we were to implement a distributional np.heaviside whose derivative is a Dirac delta (as opposed to using if/else statements), I can probably abuse the machinery to get reparameterization gradients for free.
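A hand-worked sketch of this trick, assuming $w \sim \mathrm{Uniform}(0, 1)$, a single jump of the form $a \, I[w < g(x)] + b \, I[w \geq g(x)]$, and $0 < g(x) < 1$. The helper `jump_gradient` is hypothetical and is not a JAX API; it just applies what integrating the delta over $w$ leaves behind, namely $(a - b)\, g'(x)$, with $g'$ computed by `jax.grad`:

```python
import jax
import jax.numpy as jnp

def jump_gradient(g, a, b, x):
    """Hypothetical helper: d/dx E_w[a*I[w < g(x)] + b*I[w >= g(x)]], w ~ Uniform(0, 1).

    Assumes 0 < g(x) < 1. The distributional derivative of the integrand is
    (a - b) * g'(x) * delta(w - g(x)); integrating over w in (0, 1) leaves (a - b) * g'(x).
    """
    return (a - b) * jax.grad(g)(x)

x = 0.5
print(jump_gradient(jnp.cos, 2.0, 3.0, x))  # approximately 0.4794
print(jnp.sin(x))                           # the analytic answer sin(0.5), for comparison
```

A real distributional `heaviside` primitive would have to carry this delta through arbitrary programs, which is the hard part; this sketch only covers the single-jump, uniform case discussed here.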

mattjj commented 5 years ago

Distributional derivatives are cheating :)

You're right that maybe we can phrase this challenge as being about automatic distributional derivatives, and how those ultimately interact with integration. And you're also right that to catch this kind of thing in Python we'll probably need a special primitive like heaviside rather than writing it in terms of raw Python control flow. There are a lot of details to work out, but it's pretty interesting.

proteneer commented 5 years ago

That is a fantastic paper, thanks for linking it.

I also think we're roughly in agreement, though I'm obviously biased by being in the distributional camp; it definitely makes sense to think on this for a while before implementing anything drastic. We can rename the issue to something more appropriate, like "distributional derivatives / heaviside", or whatever you feel is appropriate.

There are so many cool applications (e.g. Metropolis-based Monte Carlo algorithms, replica exchange) that one could in principle build with this. Originally, I was intending to implement a differentiable Monte Carlo pressure-coupling method (a barostat), but I can look for a non-Monte Carlo one for the time being.

PS: not cheating, just non-rigorous

fehiepsi commented 5 years ago

> do you know of work on automatically forming reparameterization-style Monte Carlo estimators for derivatives of integral representations?

I am not familiar with this line of work. The above discussion is very interesting! I'll set aside some time this weekend to follow up. :)

mattjj commented 2 months ago

I'm going to close this issue as it's the least-recently-updated still-open issue on our tracker, and I think the discussion was pretty thorough!