Does it make sense to take a derivative with respect to an integer-valued argument? I could treat it as a float, but then the derivatives would be incorrect if the function contained integer-specific behavior like integer division. I also like the idea that the type of the derivative (float, array, list, etc.) is always the same as the type of the variable with respect to which the function is differentiated.
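As a concrete sketch of the worry (assuming autograd's `grad`; the lambda names below are just for illustration):

```python
# Sketch of the concern, assuming autograd's `grad`; the lambdas are illustrative only.
from autograd import grad

smooth = lambda x: x / 4.0    # ordinary division: derivative is 0.25 everywhere
floored = lambda x: x // 4    # integer-style division: piecewise constant

print(grad(smooth)(2.0))      # 0.25
# `floored` is flat almost everywhere, so silently treating an int input as a
# float and differentiating x / 4.0 instead would report 0.25 where the
# integer-valued function actually has slope 0.
```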
One perspective is that you should be able to take the derivative of any (numeric) function. There's no barrier in principle to integers: polynomials with integer coefficients, for example, are nicely closed under differentiation.
In math, and in Scheme, you'd regard the integers as a subset of the reals, so this isn't an issue. But in Python, I dunno, I can see it both ways.
If you take a differential geometry perspective, where reverse mode is a pull-back and thus defined by a dual construction over the push-forward, then even if the primal quantities are integers, the tangent and cotangent spaces are vector spaces over the reals. So derivatives would be float even when the primals are integers.
Anyway, a second perspective is that when f(2.0) and f(2) are "the same" then their derivatives should both be defined, and both be "the same" in the same sense of sameness.
Ok, I can see both sides. The problem is that Python treats integers and floats differently, so that `f(2)` and `f(2.0)` are not the same in general (consider `f = lambda x: x / 4`). What should the result of `grad(lambda x: x / 4)(2)` be? 0? 0.25?
I'd say 0, but there are two different 0s it might be. One being the real number 0. The other being the zero element of a zero-dimensional vector space, which is of course the only element of that vector space, aka ().
I guess the bottom line is that Python is a mess? Because this wouldn't be an issue in Scheme or Haskell.
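To make the two readings concrete, here is a small sketch (Python 2 division semantics assumed for the integer case, since that's the context of this thread; `grad` is autograd's):

```python
from autograd import grad

f = lambda x: x / 4

# Reading 1: honor integer semantics. Under Python 2, 2 / 4 == 0 and f is
# piecewise constant on the integers, so the derivative at 2 would be 0.
# Reading 2: promote 2 to 2.0; then f(2.0) == 0.5 and the derivative is 0.25:
print(grad(f)(2.0))   # 0.25
```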
I think @dougalm fixed this in 1d5375d. The above example seems to work!
```
In [2]: grad(lambda x:x*x)(2)
autograd/core.py:209: UserWarning: Casting int to float to handle differentiation.
  warnings.warn("Casting int to float to handle differentiation.")
Out[2]: 4.0
```
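For context, a minimal sketch of the general idea the warning points at, not the actual change in 1d5375d; `safe_grad` and its casting logic here are hypothetical:

```python
import warnings
from autograd import grad

def safe_grad(fun):
    """Hypothetical wrapper: differentiate fun, casting int inputs to float first."""
    g = grad(fun)
    def wrapped(x):
        if isinstance(x, int):
            warnings.warn("Casting int to float to handle differentiation.")
            x = float(x)
        return g(x)
    return wrapped

print(safe_grad(lambda x: x * x)(2))   # 4.0, after warning about the cast
```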