FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Unsupported Control Flow with custom adjoint #41

Closed: willtebbutt closed this issue 5 years ago

willtebbutt commented 5 years ago

I'm looking at implementing a couple of custom adjoints. With this MWE:

using Zygote, Random
using Zygote: @adjoint
import LinearAlgebra: \

@adjoint function \(A::AbstractMatrix, B::AbstractVector)
    Y = A \ B
    return Y, function(Ȳ)
        B̄ = A' \ Ȳ
        return (-B̄ * Y', B̄)
    end
end

rng, P, Q = MersenneTwister(123456), 10, 9
X, Y = randn(rng, P, P), randn(rng, P, Q)

f = (X, Y)->sum(X \ Y)
Zygote.gradient(f, X, Y)

I get this error:

ERROR: Compiling Tuple{typeof(\),Array{Float64,2},Array{Float64,2}}: Unsupported control flow
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] merge_returns(::Core.Compiler.IRCode) at /home/wct23/.julia/dev/Zygote/src/compiler/reverse.jl:29
 [3] #Primal#39(::Nothing, ::Type, ::Core.Compiler.IRCode) at /home/wct23/.julia/dev/Zygote/src/compiler/reverse.jl:193
 [4] Type at ./none:0 [inlined]
 [5] #Adjoint#65 at /home/wct23/.julia/dev/Zygote/src/compiler/reverse.jl:392 [inlined]
 [6] (::getfield(Core, Symbol("#kw#Type")))(::NamedTuple{(:varargs,),Tuple{Nothing}}, ::Type{Zygote.Adjoint}, ::Core.Compiler.IRCode) at ./none:0
 [7] _lookup_grad(::Type) at /home/wct23/.julia/dev/Zygote/src/compiler/emit.jl:121
 [8] #s54#851 at /home/wct23/.julia/dev/Zygote/src/compiler/interface2.jl:17 [inlined]
 [9] #s54#851(::Any, ::Any, ::Any) at ./none:0
 [10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:506
 [11] #7 at ./REPL[7]:1 [inlined]
 [12] (::Zygote.J{Tuple{getfield(Main, Symbol("##7#8")),Array{Float64,2},Array{Float64,2}},Tuple{getfield(Main, Symbol("##7#8")),Array{Float64,2},Array{Float64,2},getfield(Zygote, Symbol("##1884#back#734")){getfield(Zygote, Symbol("##730#732")){Array{Float64,2}}},Zygote.J{Tuple{typeof(\),Array{Float64,2},Array{Float64,2}},Tuple{typeof(\)}}}})(::Int8) at /home/wct23/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [13] (::getfield(Zygote, Symbol("##66#67")){Zygote.J{Tuple{getfield(Main, Symbol("##7#8")),Array{Float64,2},Array{Float64,2}},Tuple{getfield(Main, Symbol("##7#8")),Array{Float64,2},Array{Float64,2},getfield(Zygote, Symbol("##1884#back#734")){getfield(Zygote, Symbol("##730#732")){Array{Float64,2}}},Zygote.J{Tuple{typeof(\),Array{Float64,2},Array{Float64,2}},Tuple{typeof(\)}}}}})(::Int8) at /home/wct23/.julia/dev/Zygote/src/compiler/interface.jl:38
 [14] gradient(::Function, ::Array{Float64,2}, ::Vararg{Array{Float64,2},N} where N) at /home/wct23/.julia/dev/Zygote/src/compiler/interface.jl:44
 [15] top-level scope at none:0

Any thoughts? Am I doing something wrong, or is this an issue on Zygote's end?

(I'm running on Zygote and IRTools master, and Julia 1.0.2)
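For reference, the pullback in the MWE can be derived from $Y = A^{-1}B$, assuming real-valued $A$ and $B$. Here $\bar{Y}$ is the incoming cotangent and $\langle X, Z \rangle = \operatorname{tr}(X^\top Z)$ is the Frobenius inner product:

```latex
\begin{aligned}
dY &= -A^{-1}\,dA\,A^{-1}B + A^{-1}\,dB = A^{-1}\,(dB - dA\,Y),\\
\langle \bar{Y}, dY \rangle &= \langle A^{-\top}\bar{Y},\; dB - dA\,Y \rangle
  = \langle \bar{B}, dB \rangle - \langle \bar{B}\,Y^{\top}, dA \rangle,
  \qquad \bar{B} = A^{-\top}\bar{Y},\\
\bar{A} &= -\bar{B}\,Y^{\top}.
\end{aligned}
```

The last two lines correspond to `B̄ = A' \ Ȳ` and `-B̄ * Y'` in the adjoint. Nothing in the derivation requires $B$ to be a vector, so the same rule also applies to a matrix right-hand side, such as the P×Q matrix passed as the second argument in the MWE.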

jekbradbury commented 5 years ago

There's at least one control-flow-related Zygote issue that's fixed on Julia master, I think, so that might be worth trying?

MikeInnes commented 5 years ago

This is obviously still a valid bug, but note that you aren't hitting your custom adjoint here (you seem to be passing matrix x matrix into \).
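The dispatch mismatch can be checked without Zygote. A method is selected only when the tuple of argument types is a subtype of its signature. The sketch below uses plain subtype checks. `vec_or_mat` is a hypothetical broadened signature and was not proposed in the thread.

```julia
using Random

# Argument types in the MWE: X is 10×10 and Y is 10×9, so both are Matrix{Float64}.
rng = MersenneTwister(123456)
X, Y = randn(rng, 10, 10), randn(rng, 10, 9)
argtypes = Tuple{typeof(X), typeof(Y)}

# Signature of the custom adjoint in the MWE (vector right-hand side only).
vector_only = Tuple{AbstractMatrix, AbstractVector}

# Hypothetical broadened signature covering vector and matrix right-hand sides.
vec_or_mat = Tuple{AbstractMatrix, AbstractVecOrMat}

println(argtypes <: vector_only)                           # false: custom adjoint not selected
println(argtypes <: vec_or_mat)                            # true
println(Tuple{typeof(X), Vector{Float64}} <: vector_only)  # true: a vector RHS would match
```

Because the custom rule does not match, Zygote instead tries to differentiate through `\` for two `Array{Float64,2}` arguments. This is the `Compiling Tuple{typeof(\),Array{Float64,2},Array{Float64,2}}` step that fails in the stacktrace above.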

willtebbutt commented 5 years ago

> This is obviously still a valid bug, but note that you aren't hitting your custom adjoint here (you seem to be passing matrix x matrix into \).

I'm an idiot. Thanks for catching this. Works fine now.
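Once the adjoint is actually selected, a quick numerical check can confirm that the pullback holds for both vector and matrix right-hand sides. The sketch below uses only the standard library and assumes the loss `sum(A \ B)` from the MWE. `solve_pullback` and `fd_grad` are hypothetical helpers, and `Abar`, `Bbar` and `Ybar` are ASCII stand-ins for the MWE's barred names.

```julia
using LinearAlgebra, Random

# Pullback of Y = A \ B, following the MWE's custom adjoint (ASCII names here).
function solve_pullback(A, B, Ybar)
    Y = A \ B
    Bbar = A' \ Ybar
    return (-Bbar * Y', Bbar)
end

# Central finite-difference gradient of a scalar function f at array x.
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

rng = MersenneTwister(123456)
A = randn(rng, 10, 10) + 10I  # diagonal shift keeps A well conditioned
errs = Float64[]              # max abs error of each analytic-vs-FD comparison

for B in (randn(rng, 10), randn(rng, 10, 9))  # vector and matrix right-hand sides
    Ybar = ones(size(A \ B))  # cotangent of sum(Y) is all ones
    Abar, Bbar = solve_pullback(A, B, Ybar)
    push!(errs, maximum(abs, Abar - fd_grad(a -> sum(a \ B), A)))
    push!(errs, maximum(abs, Bbar - fd_grad(b -> sum(A \ b), B)))
end

println(errs)  # every entry should be at finite-difference noise level
```

The diagonal shift exists only to keep the finite differences well behaved. The adjoint itself makes no such assumption beyond `A` being invertible.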

MikeInnes commented 5 years ago

No worries, if Zygote's errors don't make the issue obvious then that's definitely still its problem.