xrq-phys closed this issue 4 years ago
Update:
After adding more hooks/intermediate outputs I'm getting:
id(A) = A
@adjoint id(Abarrier) = Abarrier, Ābarrier -> begin
    @show Abarrier
    @show Ābarrier
    (Ābarrier,)
end
mymul(a::Union{Matrix, Adjoint}, b::Union{Matrix, Adjoint}, label::String) = a * b
@adjoint mymul(a::Union{Matrix, Adjoint}, b::Union{Matrix, Adjoint}, label::String) = begin
    @show label
    mymul(a, b, label), c̄ -> begin
        println(label, " 1st diff begin")
        c̄ = id(c̄)
        ā = mymul(c̄', b, "ā")
        b̄ = mymul(a, c̄', "b̄")
        println(label, " 1st diff done")
        (ā, b̄)
    end
end
julia> c1, dc1 = pullback(x -> mymul([1.0 2.0; 3.0 4.0].* x, [1.0 2.0; 1.0 2.0] .* x, "c")[1, 1], 1.0)
label = "c"
(3.0, Zygote.var"#43#44"{typeof(∂(#11))}(∂(#11)))
julia> c2, dc2 = pullback(dc1, 1.0)
c 1st diff begin
label = "ā"
label = "b̄"
c 1st diff done
((8.0,), Zygote.var"#43#44"{typeof(∂(λ))}(∂(λ)))
julia> dc2(1.0)
b̄ 1st diff begin
b̄ 1st diff done
ā 1st diff begin
ā 1st diff done
Abarrier = [1.0 0.0; 0.0 0.0]
Ābarrier = [9.0 17.0; 13.0 23.0]
ERROR: Mutating arrays is not supported
So it seems that Zygote does use the custom adjoints recursively, as expected, but some implicit accumulation when evaluating dc1 is what makes it break down. Sorry for the confusion.
Could you try https://github.com/DhairyaLGandhi/Zygote.jl/tree/dg/iddict ?
I think I set it up as
gradient(...) do ...
    y, back = pullback(...)
    back(y)
end
It worked after I changed:
pullback(x -> mymul(Ap .* x^2, Bp, "c")[1, 1], 1.0)
to:
pullback(x -> sum(mymul(Ap .* x^2, Bp, "c") .* [1 0; 0 0]), 1.0)
It seems reasonable that taking an index causes mutation, and the patch by @DhairyaLGandhi seems to have fixed another rule that prevents case 2 from being executed on current master. My initial description missed the point.
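To illustrate why the masked sum can stand in for the index, here is a minimal sketch in plain Julia. The helpers index_pullback and mask_pullback are hypothetical, hand-written stand-ins and are not Zygote's actual source; the matrices are the ones from the REPL session, since Ap and Bp are not shown in the thread. The indexing pullback writes into a fresh array in place, while the masked pullback is a pure broadcast.

```julia
# Hypothetical hand-written pullbacks for illustration only (not Zygote's source).

# Pullback of C -> C[i, j]: scatter the scalar cotangent into a zero array.
# The in-place write below is the kind of operation that Zygote reports as
# "Mutating arrays is not supported" when it has to differentiate through it.
function index_pullback(C::AbstractMatrix, i::Integer, j::Integer, dy::Real)
    dC = zeros(eltype(C), size(C))
    dC[i, j] = dy
    return dC
end

# Pullback of C -> sum(C .* mask): a broadcast with no in-place write.
mask_pullback(mask::AbstractMatrix, dy::Real) = dy .* mask

C = [1.0 2.0; 3.0 4.0] * [1.0 2.0; 1.0 2.0]
mask = [1 0; 0 0]

# Both formulations give the same forward value and the same cotangent.
@assert C[1, 1] == sum(C .* mask)
@assert index_pullback(C, 1, 1, 1.0) == mask_pullback(mask, 1.0)
```

The two forms are interchangeable for a single element; for a general index set the mask would need ones at every selected position.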
Thanks! Please allow me to check a little more before closing this issue.
Hi,
I'm trying to use Zygote.jl up to 2nd order with some external matrix routines, but ran into a problem when trying to use custom gradients recursively(?).
Here is an example to reproduce:
I suppose that the @adjoint block here also defines the pullback for mymul invocations inside the pullback method itself. But when I try to evaluate a second-order pullback as:
the program runs into a Mutating arrays is not supported error, which implies that Zygote.jl is still digging into the definition of mymul when trying to differentiate the 1st-order pullback dc1. (Differentiating dc1 is expected to use the already-defined @adjoint mymul again.)
I'm wondering whether this is expected behaviour on Zygote.jl's side or a bug. In either case, is there a workaround?
PS: The REPL output of the code above is:
Hope I didn't make mistakes in writing the @adjoint blocks.
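One way to cross-check a hand-written adjoint for a matrix product is to compare it against finite differences. The sketch below assumes mymul is a plain product C = A * B and uses the textbook reverse-mode rule (dA = dC * B', dB = A' * dC). The names matmul_pullback, f, h, and fdA are hypothetical and introduced only for this illustration; nothing here calls Zygote.jl.

```julia
# Textbook reverse-mode rule for C = A * B (an assumption about what mymul
# is meant to compute, not code taken from the thread):
#   dA = dC * B',  dB = A' * dC
matmul_pullback(A, B, dC) = (dC * B', A' * dC)

A = [1.0 2.0; 3.0 4.0]
B = [1.0 2.0; 1.0 2.0]
dC = [1.0 0.0; 0.0 0.0]   # cotangent that selects C[1, 1]

dA, dB = matmul_pullback(A, B, dC)

# Central finite-difference estimate of d(C[1, 1]) / dA[i, j].
f(A, B) = (A * B)[1, 1]
h = 1e-6
fdA = similar(A)
for i in axes(A, 1), j in axes(A, 2)
    E = zeros(size(A))
    E[i, j] = h
    fdA[i, j] = (f(A + E, B) - f(A - E, B)) / (2h)
end

# The analytic cotangent should agree with the numerical estimate.
@assert isapprox(dA, fdA; atol = 1e-6)
```

The same loop over B gives a check for dB. Any custom rule can be compared entry by entry against such an estimate before it is used at second order.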