HIPS / autograd

Efficiently computes derivatives of NumPy code.
MIT License

Nth derivative #68

Closed: r0fls closed this issue 9 years ago

r0fls commented 9 years ago

I thought the example looked a little clunky where it does the following to compute an nth derivative:

>>> import autograd.numpy as np
>>> from autograd import grad
>>> grad_tanh = grad(np.tanh)               # 1st derivative
>>> grad_tanh_2 = grad(grad_tanh)           # 2nd derivative
>>> grad_tanh_3 = grad(grad_tanh_2)         # 3rd derivative
>>> grad_tanh_4 = grad(grad_tanh_3)         # etc.
>>> grad_tanh_5 = grad(grad_tanh_4)
>>> grad_tanh_6 = grad(grad_tanh_5)

Would it be worth having an optional second argument in grad, say:

>>> grad_tanh_6 = grad(tanh, 6)

Or just a different function?

duvenaud commented 9 years ago

Thanks for the suggestion! In our experience, people rarely want plain higher derivatives; usually they want something more like Hessian-vector products.

Whenever we've been tempted to make the interface to grad() more complicated, we've instead added a helper function here: https://github.com/HIPS/autograd/blob/master/autograd/convenience_wrappers.py

Now that I look at it again, that example should maybe use elementwise_grad; at the very least, it would be a lot faster.

mattjj commented 9 years ago

That exact spec wouldn't work because grad already takes a second argument, which is argnum (default argnum=0).

I agree that it looks clunky, but it does communicate that you can just keep calling grad on the output of grad, with no special bookkeeping happening.

mattjj commented 9 years ago

Just for fun, here are two implementations of a higher-order function, repeated, that one could wrap around grad:

def repeated(f, n):
    # Iterative version: apply f to x, n times in a row.
    def fn(x):
        for _ in range(n):
            x = f(x)
        return x
    return fn

def repeated(f, n):
    # Recursive version: peel off one application of f per call.
    def helper(n, x):
        return helper(n - 1, f(x)) if n > 0 else x
    return lambda x: helper(n, x)

So one could write repeated(grad, 10)(np.tanh)(x) or something. (In addition, David's suggestion to use elementwise_grad for broadcasting makes a lot of sense!)