FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Custom Adjoint when Taking Gradient of Gradient (2nd Gradient) #745

Closed: xrq-phys closed this issue 4 years ago

xrq-phys commented 4 years ago

Hi,

I'm trying to use Zygote.jl up to second order together with some external matrix routines, but I ran into a problem when using custom gradients recursively(?).

Here is an example that reproduces it:

using LinearAlgebra
using Zygote: pullback, @adjoint

mymul(a::Union{Matrix, Adjoint}, b::Union{Matrix, Adjoint}, label::String) = a * b

@adjoint mymul(a::Union{Matrix, Adjoint}, b::Union{Matrix, Adjoint}, label::String) = begin
    @show label
    mymul(a, b, label), c̄ -> begin # <Subblock-1
        @show label
        ā = mymul(c̄', b, "ā")
        b̄ = mymul(a, c̄', "b̄")
        (ā, b̄)
    end # Subblock-1>
end

I assume the @adjoint block here also defines the pullback for the mymul calls made inside the pullback closure itself. But when I try to evaluate a second-order pullback like this:

Ap = [1.0 2.0;
      3.0 4.0];
Bp = [1.0 2.0;
      1.0 2.0];
c1, dc1 = pullback(x -> mymul(Ap .* x^2, Bp, "c")[1, 1], 1.0)
c2, dc2 = pullback(dc1, 1.0)

dc2(1.0)  # the error is raised when evaluating this expression

the program fails with a "Mutating arrays is not supported" error. This suggests that Zygote.jl still traces into the definition of mymul when differentiating the first-order pullback dc1, whereas I expected differentiating dc1 to hit the already-defined @adjoint for mymul again.
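For context, Zygote raises this error whenever the code it differentiates writes into an array in place, because its reverse pass does not support mutation. A minimal sketch that triggers it at first order (the function name `write_then_sum` is hypothetical; the message starts with "Mutating arrays is not supported", and the details after that may vary between Zygote versions):

```julia
using Zygote: gradient

# Hypothetical function that builds its result by writing into a buffer.
function write_then_sum(x)
    buf = zeros(2)
    buf[1] = x^2        # setindex! on an Array: Zygote's reverse pass rejects this
    return sum(buf)
end

msg = try
    gradient(write_then_sum, 3.0)
    "no error"
catch err
    sprint(showerror, err)
end

println(first(split(msg, '\n')))   # first line of the error message
```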

Is this expected behaviour on Zygote.jl's side, or could it be a bug? In either case, is there a workaround?

PS: The REPL output of the code above is:

# c1, dc1 = ...
label = "c"
(3.0, Zygote.var"#43#44"{typeof(∂(#9))}(∂(#9)))

# c2, dc2 = ...
label = "c"
label = "ā"
label = "b̄"
((10.0,), Zygote.var"#43#44"{typeof(∂(λ))}(∂(λ)))

# dc2(1)
label = "c"
label = "ā"
ERROR: Mutating arrays is not supported

I hope I didn't make mistakes writing the @adjoint blocks.
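A quick first-order check suggests the transposes in the pullback above are off. For C = A*B the usual reverse-mode rule is Ā = C̄Bᵀ and B̄ = AᵀC̄, whereas the @adjoint above uses c̄' * b and a * c̄'. The sketch below assumes a hypothetical `mymul_ref` written with the usual rule (plus `nothing` for the String argument, so the returned tuple has one entry per argument) and uses a masked sum in place of [1, 1], so the objective is again (Ap * Bp)[1, 1] * x^2 = 3x^2. It compares Zygote's gradient through `mymul_ref` with its built-in rule for `*`, and evaluates the issue's ā expression by hand.

```julia
using LinearAlgebra
using Zygote: gradient, @adjoint

# Hypothetical name, so it does not clash with the issue's `mymul`.
mymul_ref(a::AbstractMatrix, b::AbstractMatrix, label::String) = a * b

# Usual rule for C = A * B: Ā = C̄ Bᵀ, B̄ = Aᵀ C̄; the String label gets `nothing`.
@adjoint function mymul_ref(a::AbstractMatrix, b::AbstractMatrix, label::String)
    c = mymul_ref(a, b, label)
    function mymul_ref_pullback(cbar)
        abar = mymul_ref(cbar, b', "abar")
        bbar = mymul_ref(a', cbar, "bbar")
        return (abar, bbar, nothing)
    end
    return c, mymul_ref_pullback
end

Ap = [1.0 2.0; 3.0 4.0]
Bp = [1.0 2.0; 1.0 2.0]
mask = [1.0 0.0; 0.0 0.0]   # sum(C .* mask) == C[1, 1]

# First-order gradients at x = 1 through the custom rule and the built-in `*` rule.
g_ref = gradient(x -> sum(mymul_ref(Ap .* x^2, Bp, "c") .* mask), 1.0)[1]
g_builtin = gradient(x -> sum((Ap .* x^2) * Bp .* mask), 1.0)[1]

# The issue's ā = c̄' * b, contracted with dA/dx = 2x .* Ap at x = 1.
g_issue_rule = dot(mask' * Bp, 2 .* Ap)

println((g_ref, g_builtin, g_issue_rule))
```

Both AD gradients should come out as 6.0, the derivative of 3x^2 at x = 1, while the issue's ā expression gives 10.0, the same value printed for c2 above. This is separate from the mutation error, but any second-order result built on that pullback would inherit it.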

xrq-phys commented 4 years ago

Update:

After adding more hooks and intermediate output, I get:

id(A) = A
@adjoint id(Abarrier) = Abarrier, Ābarrier -> begin
    @show Abarrier
    @show Ābarrier
    (Ābarrier, )
end

mymul(a::Union{Matrix, Adjoint}, b::Union{Matrix, Adjoint}, label::String) = a * b

@adjoint mymul(a::Union{Matrix, Adjoint}, b::Union{Matrix, Adjoint}, label::String) = begin
    @show label
    mymul(a, b, label), c̄ -> begin
        println(label, " 1st diff begin")
        c̄ = id(c̄)
        ā = mymul(c̄', b, "ā")
        b̄ = mymul(a, c̄', "b̄")
        println(label, " 1st diff done")
        (ā, b̄)
    end
end

julia> c1, dc1 = pullback(x -> mymul([1.0 2.0; 3.0 4.0].* x, [1.0 2.0; 1.0 2.0] .* x, "c")[1, 1], 1.0)
label = "c"
(3.0, Zygote.var"#43#44"{typeof(∂(#11))}(∂(#11)))

julia> c2, dc2 = pullback(dc1, 1.0)
c 1st diff begin
label = "ā"
label = "b̄"
c 1st diff done
((8.0,), Zygote.var"#43#44"{typeof(∂(λ))}(∂(λ)))

julia> dc2(1.0)
b̄ 1st diff begin
b̄ 1st diff done
ā 1st diff begin
ā 1st diff done
Abarrier = [1.0 0.0; 0.0 0.0]
Ābarrier = [9.0 17.0; 13.0 23.0]
ERROR: Mutating arrays is not supported

So Zygote does seem to apply the adjoints recursively as expected, but some implicit accumulation while evaluating dc1 causes it to break down. Sorry for the confusion.

DhairyaLGandhi commented 4 years ago

Could you try this branch: https://github.com/DhairyaLGandhi/Zygote.jl/tree/dg/iddict

I think I set it up as:

gradient(...) do ...
  y, back = pullback(...)
  back(y)
end
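The gradient-over-pullback pattern in the snippet above can be sketched on a scalar function. In this minimal sketch, `sin` and the helper names `first_derivative` and `second_derivative` are assumptions for illustration; the inner pullback is seeded with 1.0 so that it returns f′(x), and the outer `gradient` differentiates through that call to give f″(x).

```julia
using Zygote: gradient, pullback

# First derivative via an explicit pullback, seeded with 1.0.
function first_derivative(f, x)
    y, back = pullback(f, x)
    return back(1.0)[1]   # back returns a tuple with one entry per argument
end

# Second derivative: let Zygote differentiate through the first-order pullback.
second_derivative(f, x) = gradient(z -> first_derivative(f, z), x)[1]

d2 = second_derivative(sin, 0.5)
println(d2 ≈ -sin(0.5))
```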
xrq-phys commented 4 years ago

It worked after I changed:

pullback(x -> mymul(Ap .* x^2, Bp, "c")[1, 1], 1.0)

to:

pullback(x -> sum(mymul(Ap .* x^2, Bp, "c").*[ 1 0; 0 0 ]), 1.0)

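It seems reasonable that taking an index causes mutation, presumably because the pullback of indexing builds its gradient by writing into an array, which Zygote then cannot differentiate again. The patch by @DhairyaLGandhi also seems to fix another rule that prevents case 2 from running on current master. My initial description missed the real cause.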
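The workaround presumably helps because both forms select the same entry and have the same first-order gradient, but only indexing builds that gradient by writing into an array. The pure-Julia sketch below is a simplified, hypothetical model of the two pullbacks (the helper names are made up, and this is not Zygote's actual implementation); it only checks that both produce the same cotangent for C.

```julia
# Value from the first example at x = 1: Ap * Bp.
C = [3.0 6.0; 7.0 14.0]
mask = [1.0 0.0; 0.0 0.0]

# Both forms pick out the same entry.
println(sum(C .* mask) == C[1, 1])   # true

# Simplified model of an indexing pullback: allocate zeros, then write in place.
function index_pullback_model(ybar, C)
    Cbar = zero(C)
    Cbar[1, 1] = ybar   # in-place write, the kind of step Zygote cannot differentiate
    return Cbar
end

# Pullback of C -> sum(C .* mask): a plain broadcast, no mutation.
mask_pullback_model(ybar, mask) = ybar .* mask

println(index_pullback_model(1.0, C) == mask_pullback_model(1.0, mask))   # true
```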

Thanks! Let me check a little more before closing this issue.