HIPS / autograd

Efficiently computes derivatives of NumPy code.
MIT License

check_grads (forward mode) fails on a simple custom function #550

Open Guillaume-Garrigos opened 4 years ago

Guillaume-Garrigos commented 4 years ago

Hi, here is a simple example in which check_grads raises an error:

from autograd.extend import primitive, defvjp, defjvp
from autograd.test_util import check_grads

@primitive
def f(x):
    return x**2

def f_jvp_custom(ans, x): 
    def Jf_x(d): 
        return 2*x*d
    return Jf_x

def f_vjp_custom(ans, x): 
    def Jf_x(d):
        return 2*x*d
    return Jf_x

defvjp(f,f_vjp_custom)
defjvp(f,f_jvp_custom)

check_grads(f, modes=['rev'], order=1)(3.0) # This works
check_grads(f, modes=['fwd'], order=1)(3.0) # Not working
check_grads(f, order=1)(3.0) # Not working
check_grads(f)(3.0) # Not working

It seems that there is an issue with the forward mode (check_grads seems to check both modes by default). The error I get is the following:

TypeError: f_jvp_custom() takes 2 positional arguments but 3 were given

Also, checking the JVP and VJP directly suggests that the problem comes from check_jvp:

from autograd.test_util import check_vjp, check_jvp

check_vjp(f, 3.0) # This works
check_jvp(f, 3.0) # Not working, same error message

So I think I am doing something wrong here, but I do not see what.

Gattocrucco commented 4 years ago

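JVPs are defined differently from VJPs: the function takes the tangent d as its first argument, followed by ans and x, and returns the product directly instead of returning a closure. Try this: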

def f_jvp_custom(d, ans, x):
    return 2*x*d