FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

A simple custom gradient with @adjoint for a Base function #1088

Open SimonCoste opened 3 years ago

SimonCoste commented 3 years ago

Hi, Julia's Base round function takes a number and returns the nearest integer. I want to define a custom gradient for it, namely the identity. I used Zygote's @adjoint macro in the simplest way,

@adjoint round(x)=round(x), y->(y,)

but as a result, gradient(round, 0) returns nothing, and I can't compute gradients of functions composed with round.

Is it possible to directly redefine the adjoints of Base functions?

On the other hand, if I define a function g(x) = round(x) and then I use the adjoint macro,

@adjoint g(x)=g(x), y->(y,)

then gradient(g, 0) works well, and so does the rest of the backprop machinery, but only for compositions of scalar functions; e.g. it correctly computes the gradient of x -> g(x)^2. It does not work when g is applied entry-wise: for instance,

F(x) = sum(g.(x))
gradient(F, [0,0])

still returns nothing, although this should be easy to differentiate. Do I also have to define the gradient for the broadcast g.?

mcabbott commented 3 years ago

I believe that with round, the definition linked here takes priority over yours, and making yours more specific runs into other problems:

https://github.com/FluxML/Zygote.jl/blob/2d5edf44ad7191fcd5d5816d8ab161dff8e14765/src/lib/number.jl#L2

julia> Zygote.@adjoint round(x::Float64)=round(x), y->(y,)

julia> gradient(round, 1.2)
ERROR: MethodError: _pullback(::Zygote.Context, ::typeof(round), ::Float64) is ambiguous. Candidates:
  _pullback(__context__::ZygoteRules.AContext, var"310"::typeof(round), x::Float64) in Main at /Users/me/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:56
  _pullback(::Zygote.Context, ::typeof(round), args...) in Zygote at /Users/me/.julia/dev/Zygote/src/lib/grad.jl:8
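The shape of this ambiguity can be reproduced in plain Julia without Zygote. The sketch below is only an analogy: AbstractCtx, Ctx and pullback_like are hypothetical stand-ins for ZygoteRules.AContext, Zygote.Context and _pullback, not Zygote's actual definitions. One method is more specific in the context argument and the other in the value argument, so neither wins.

```julia
# Hypothetical stand-ins for ZygoteRules.AContext and Zygote.Context
abstract type AbstractCtx end
struct Ctx <: AbstractCtx end

# Shaped like the @adjoint-generated method: general context, specific argument type
pullback_like(::AbstractCtx, x::Float64) = :user_rule

# Shaped like the existing catch-all method: specific context, any arguments
pullback_like(::Ctx, args...) = :builtin_rule

# Neither method is more specific in every position, so this call is ambiguous
err = try
    pullback_like(Ctx(), 1.2)
catch e
    e
end

println(err isa MethodError)
```

An Int argument, by contrast, only matches the catch-all method, which is why the problem shows up only for the type named in the more specific rule.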

For broadcasting, what's happening is that, for reasons of speed, Zygote prefers to use ForwardDiff when possible. (Using Zygote's own gradients often causes inference problems, and slowdowns by a factor of 100.)

julia> g(x) = @show(x)^3
g (generic function with 1 method)

julia> Zygote.@adjoint g(x)=g(x), y->(@show(y),)

julia> gradient(g, 1.2)
x = 1.2
y = 1.0
(1.0,)

julia> gradient(x -> sum(g.(x)), [1.2, 3.4])
x = Dual{Nothing}(1.2,1.0)
x = Dual{Nothing}(3.4,1.0)
([4.32, 34.67999999999999],)

If you want to overload this, you should probably define a method for dual numbers. Something like g(x::Dual{Z}) where {Z} = Dual{Z}(g(value(x)), map(??, partials(x))). Usually, though, dual numbers work through most things, or will complain loudly when they fail.
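For the rounding case from the original question, a minimal sketch of such a dual-number method could look like the following. It assumes ForwardDiff's Dual, value and partials, and uses a hypothetical wrapper name ste_round instead of overloading Base.round. Passing the partials through unchanged is one way to get the identity gradient that was asked for; it is an assumption about the intended rule, not something Zygote or ForwardDiff defines.

```julia
using ForwardDiff
using ForwardDiff: Dual, value, partials

# Hypothetical wrapper, so that Base.round itself is left untouched
ste_round(x::Real) = round(x)

# Dual-number method: round the primal value and pass the partials
# through unchanged, i.e. treat the derivative of rounding as 1
ste_round(x::Dual{T}) where {T} = Dual{T}(round(value(x)), partials(x))

# Forward-mode derivative now follows the identity rule
d = ForwardDiff.derivative(ste_round, 1.7)
println(d)  # 1.0

# The same rule applies element-wise under broadcasting
grad = ForwardDiff.gradient(x -> sum(ste_round.(x)), [1.2, 3.4])
println(grad)  # [1.0, 1.0]
```

Since the broadcast path described above evaluates the function on Dual inputs, a method like this is what it would pick up, whereas a rule defined only through @adjoint is bypassed there.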