Hi,
Here is a simple example in which `check_grads` raises an error:
```python
from autograd.extend import primitive, check_grads, defvjp, defjvp

@primitive
def f(x):
    return x**2

def f_jvp_custom(ans, x):
    def Jf_x(d):
        return 2*x*d
    return Jf_x

def f_vjp_custom(ans, x):
    def Jf_x(d):
        return 2*x*d
    return Jf_x

defvjp(f, f_vjp_custom)
defjvp(f, f_jvp_custom)

check_grads(f, modes=['rev'], order=1)(3.0)  # This works
check_grads(f, modes=['fwd'], order=1)(3.0)  # Not working
check_grads(f, order=1)(3.0)                 # Not working
check_grads(f)(3.0)                          # Not working
```
There seems to be an issue with forward mode (`check_grads` appears to run both modes by default). The error I get is: `TypeError: f_jvp_custom() takes 2 positional arguments but 3 were given`
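The error message points to a calling-convention mismatch. Assuming autograd's extension API, `defvjp` expects a *maker* `vjp_maker(ans, x)` that returns a function of the output gradient, while `defjvp` expects a function `jvp(g, ans, x)` that takes the input tangent `g` first and returns the result directly. The dependency-free sketch below reproduces the error by calling the question's `f_jvp_custom` with that three-argument order; `f_jvp_fixed` is a hypothetical corrected version that would be registered with `defjvp(f, f_jvp_fixed)`.

```python
# Signature used in the question: a VJP-style "maker" taking (ans, x).
def f_jvp_custom(ans, x):
    def Jf_x(d):
        return 2 * x * d
    return Jf_x

# Hypothetical corrected JVP: tangent g first, returns the JVP value directly.
def f_jvp_fixed(g, ans, x):
    return 2 * x * g

# Simulate a forward-mode call jvp(g, ans, x) for f(x) = x**2 at x = 3.
x = 3.0
g = 1.0       # input tangent
ans = x ** 2  # primal output

try:
    f_jvp_custom(g, ans, x)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)           # the same TypeError as in the report
print(f_jvp_fixed(g, ans, x))  # 6.0, i.e. f'(3) * g
```

The VJP side works as written because its maker signature `(ans, x)` already matches what reverse mode passes.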
Checking the VJP and JVP directly also suggests that the problem lies in `check_jvp`:
```python
from autograd.test_util import check_vjp, check_jvp

check_vjp(f, 3.0)  # This works
check_jvp(f, 3.0)  # Not working, same error message
```
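For context on what the two checks compare, here is a simplified, dependency-free stand-in (not autograd's actual `check_jvp`/`check_vjp` code). Each one compares the user-supplied derivative against a central finite difference, assuming autograd's argument orders `jvp(g, ans, x)` for forward mode and `vjp_maker(ans, x)(g)` for reverse mode. All function names here are hypothetical.

```python
def f(x):
    return x ** 2

# JVP in (g, ans, x) order: returns the output tangent directly.
def f_jvp(g, ans, x):
    return 2 * x * g

# VJP maker in (ans, x) order: returns a function of the output gradient g.
def f_vjp_maker(ans, x):
    return lambda g: 2 * x * g

def numerical_derivative(fun, x, v=1.0, eps=1e-6):
    # Central finite difference of fun at x along direction v.
    return (fun(x + eps * v) - fun(x - eps * v)) / (2 * eps)

def close(a, b, tol=1e-6):
    # Relative tolerance, with an absolute floor of tol.
    return abs(a - b) <= tol * max(1.0, abs(b))

def check_jvp_sketch(fun, jvp, x, v=1.0):
    # Forward mode: push the input tangent v through the JVP.
    return close(jvp(v, fun(x), x), numerical_derivative(fun, x, v))

def check_vjp_sketch(fun, vjp_maker, x, g=1.0):
    # Reverse mode (scalar case): the VJP should equal g * f'(x).
    vjp = vjp_maker(fun(x), x)
    return close(vjp(g), g * numerical_derivative(fun, x))

print(check_jvp_sketch(f, f_jvp, 3.0))        # True
print(check_vjp_sketch(f, f_vjp_maker, 3.0))  # True
```

Because the forward-mode check calls the JVP with three arguments, a two-argument maker like `f_jvp_custom` fails before any numerical comparison happens, which matches `check_vjp` passing while `check_jvp` raises.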
So I suspect I am doing something wrong here, but I cannot see what.