petvana closed this 4 years ago.
This is almost right, but there is one problem:
```julia
# Central finite difference with a step scaled to the magnitude of x
function finitediff(f, x)
    ϵ = cbrt(eps(typeof(x))) * max(one(typeof(x)), abs(x))
    return (f(x + ϵ) - f(x - ϵ)) / (ϵ + ϵ)
end
```

```julia
julia> finitediff(z -> mod2pi(z), 2pi)
-82569.1859259104
```
But your rule will return 1 there.
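The number in that REPL output can be checked by hand. At $x = 2\pi$ the step is $\epsilon = \sqrt[3]{\mathrm{eps}} \cdot 2\pi \approx 3.805 \times 10^{-5}$, so $x + \epsilon$ wraps past the jump of `mod2pi` while $x - \epsilon$ does not. Taking $\operatorname{mod2pi}(2\pi + \epsilon) \approx \epsilon$ and $\operatorname{mod2pi}(2\pi - \epsilon) = 2\pi - \epsilon$:

$$
\frac{\operatorname{mod2pi}(2\pi+\epsilon)-\operatorname{mod2pi}(2\pi-\epsilon)}{2\epsilon}
\approx \frac{\epsilon-(2\pi-\epsilon)}{2\epsilon}
= 1-\frac{\pi}{\epsilon}
\approx 1 - 82570.19 = -82569.19
$$

So the finite difference is dominated by the jump, not by the slope of 1.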
You need to check whether your argument is an integer multiple of `2pi` and, if it is, return `NaN`. Have a look at this line for an example:
So is it going to be this?
```julia
@define_diffrule Base.mod2pi(x, r) = :( first(promote(ifelse(isinteger($x / 2pi), NaN, 1), NaN)) ), :( z = $x / 2pi; first(promote(ifelse(isinteger(z), NaN, -floor(z)), NaN)) )
```
I'm having trouble testing this implementation, though. How can I tell Zygote to use the local DiffRules instead of the version it has as a dependency?
I don't understand the code that you have posted. `mod2pi` has only one argument, but you made a diff rule for a two-argument function. You need to define the rule for a one-argument function, and the right-hand side only needs a single expression (the derivative).
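To make the one-argument shape concrete, here is a minimal sketch that imitates what such a rule produces without depending on DiffRules itself: a function that receives the argument (a `Symbol` or `Expr`) and returns a single derivative expression. The name `mod2pi_rule` is hypothetical; in DiffRules the rule would instead be declared with the macro, presumably in the form `@define_diffrule Base.mod2pi(x) = :( ... )`, that is, the snippet above reduced to one argument and one expression.

```julia
# Hypothetical stand-in for a one-argument DiffRules-style rule (not the DiffRules API):
# it takes the argument expression and returns ONE expression for d/dx mod2pi(x).
# mod2pi jumps at integer multiples of 2π, so the derivative is NaN there and 1 elsewhere.
mod2pi_rule(x) = :(ifelse(isinteger($x / 2pi), NaN, 1.0))

# Build the derivative expression for a variable named `y`
ex = mod2pi_rule(:y)

# Turn the expression into a callable so its values can be checked
dmod2pi = eval(:(y -> $ex))

@show ex
@show dmod2pi(1.0)
@show dmod2pi(2pi)
```

Because the rule only interpolates `$x`, the same expression works whether the argument is a plain variable or a larger expression.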
As for testing the function, just make sure that the DiffRules tests pass and that the gradient is `NaN` when the user passes an argument that is a multiple of `2pi`. Once that is sorted, you should make sure that things work as expected on the Zygote side before closing the Zygote issue.
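A minimal sketch of the `NaN` check described here, written with Julia's `Test` standard library. The function `mod2pi_derivative` is a hypothetical stand-in for the rule under review, not code from the PR. The test points are 0, ±2pi and ±4pi because scaling `Float64(pi)` by 0, 1, 2 or 4 is exact, so `x / 2pi` comes out as an exact integer.

```julia
using Test

# Hypothetical derivative under test: NaN where mod2pi jumps, 1 elsewhere
mod2pi_derivative(x) = isinteger(x / 2pi) ? NaN : 1.0

@testset "mod2pi derivative (sketch)" begin
    # Multiples of 2pi built from exact power-of-two scalings of Float64(pi)
    for x in (-4pi, -2pi, 0.0, 2pi, 4pi)
        @test isnan(mod2pi_derivative(x))
    end
    # Points away from the jumps: the slope of mod2pi is 1
    for x in (-2.0, 0.5, 1.0, 3.0)
        @test mod2pi_derivative(x) == 1.0
    end
end
```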
If I'm not mistaken, this is what we want to do? https://github.com/JuliaDiff/DiffRules.jl/compare/master...Moelf:master
As for the test, there's no way the function should return what `finitediff` gives, right?
Yes, I believe that your master branch looks good now. The only thing that I notice is that your code on line 61 is not perfectly aligned.
> As for the test, no way the function should return what `finitediff` gives right?
When the distance between `x` and every multiple of `2pi` is greater than `ϵ`, the rule that you provide and `finitediff` return (approximately) the same value. However, if `x` is equal to a multiple of `2pi`, then `finitediff` returns garbage, but your rule should return `NaN`. Looking at your code and the test that you wrote, you have implemented this behavior correctly.
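The condition above can be sketched as a check that only compares against `finitediff` at points farther than its step `ϵ` from every multiple of `2pi`. The helper `far_from_jump` and the stand-in `rule` are hypothetical; `finitediff` is the function posted earlier in this thread, and `atol = 1e-6` is an assumed tolerance.

```julia
# finitediff as posted earlier in this thread
function finitediff(f, x)
    ϵ = cbrt(eps(typeof(x))) * max(one(typeof(x)), abs(x))
    return (f(x + ϵ) - f(x - ϵ)) / (ϵ + ϵ)
end

# Hypothetical stand-in for the rule under review
rule(x) = isinteger(x / 2pi) ? NaN : 1.0

# Hypothetical helper: is x farther than finitediff's step from every jump of mod2pi?
function far_from_jump(x::Float64)
    ϵ = cbrt(eps(Float64)) * max(1.0, abs(x))
    r = mod2pi(x)                 # position of x within one period
    return min(r, 2pi - r) > ϵ    # distance to the nearest multiple of 2π
end

points = [0.5, 3.0, 10.0, 2pi, 4pi - 1e-9]
safe = filter(far_from_jump, points)

# Where the step does not straddle a jump, rule and finitediff agree
for x in safe
    @assert isapprox(rule(x), finitediff(mod2pi, x); atol = 1e-6)
end

# Exactly at a multiple of 2π, the rule reports NaN instead of finitediff's garbage
@assert isnan(rule(2pi))
```

Points such as `4pi - 1e-9` are skipped: they are not exact multiples, so the rule returns 1, but the finite-difference step still crosses the jump.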
Fix Zygote issue #118. It seems better to add the rule here rather than in the end library (Zygote).
https://github.com/FluxML/Zygote.jl/issues/118