FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Reverse on reverse fails with a Flux nn #1266

Open YichengDWu opened 2 years ago

YichengDWu commented 2 years ago
using Flux, Zygote

model = Dense(3,1)
grad_f(x) = gradient(x -> sum(model(x)),x)[1]
Zygote.jacobian(grad_f,rand(3))
ERROR: Can't differentiate foreigncall expression.
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/dev/limitations.html

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] Pullback
    @ .\iddict.jl:102 [inlined]
  [3] (::typeof(∂(get)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [4] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:68 [inlined]
  [5] (::typeof(∂(accum_global)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [6] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:79 [inlined]
  [7] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [8] Pullback
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [10] getindex
    @ .\tuple.jl:29 [inlined]
 [11] map
    @ .\tuple.jl:222 [inlined]
 [12] unthunk_tangent
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:36 [inlined]
 [13] #1789#back
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [15] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Tuple{Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [17] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:76 [inlined]
 [18] (::typeof(∂(gradient)))(Δ::Tuple{Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [19] Pullback
    @ .\REPL[10]:1 [inlined]
 [20] (::typeof(∂(grad_f)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [21] #216
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:207 [inlined]
 [22] (::Zygote.var"#1909#back#218"{Zygote.var"#216#217"{Tuple{Tuple{Nothing}}, typeof(∂(grad_f))}})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
 [23] Pullback
    @ .\operators.jl:1085 [inlined]
 [24] (::typeof(∂(#_#83)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [25] (::Zygote.var"#216#217"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#_#83))})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:207
 [26] #1909#back
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [27] Pullback
    @ .\operators.jl:1085 [inlined]
 [28] (::typeof(∂(ComposedFunction{typeof(Zygote._jvec), typeof(grad_f)}(Zygote._jvec, grad_f))))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [29] (::Zygote.var"#60#61"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), typeof(grad_f)}(Zygote._jvec, grad_f)))})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41
 [30] withjacobian(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:162
 [31] jacobian(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:140
 [32] top-level scope
    @ REPL[12]:1
 [33] top-level scope
    @ C:\Users\Luffy\.julia\packages\CUDA\DfvRa\src\initialization.jl:52
ToucheSir commented 2 years ago

The first-order gradient is treating model as a global and trying to differentiate with respect to it as well. I have no idea what the semantics of globals should be under reverse-over-reverse (we ought to make accum_global non-differentiable or ban it completely), so the easier path is to get rid of them entirely:

julia> grad_f_m(m) = x -> gradient(x -> sum(m(x)), x)[1]  # [edited to make name not clash]
grad_f_m (generic function with 1 method)

julia> grad_f_m(model)(rand(3))
3-element Vector{Float64}:
  0.28216660022735596
 -0.385393351316452
  0.08605238795280457

julia> Zygote.jacobian(grad_f_m(model), rand(3))[1]
3×3 Matrix{Float64}:
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0

Assuming the all-zero Jacobian is incorrect, it's perhaps the more "interesting" bug here.
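The same pattern without Flux, as a minimal sketch: the parameters (a plain matrix W standing in for the weight of Dense(3 => 1), with no bias) are passed in as an argument and captured by the returned closure, so the differentiated code never reads a non-constant global. make_grad and W are hypothetical names chosen for illustration, not part of Zygote or Flux.

```julia
using Zygote

# Hypothetical builder: the parameters arrive as an argument and are captured
# by the returned closure, so no non-constant global is read during AD.
make_grad(W) = x -> Zygote.gradient(y -> sum(W * y), x)[1]

W = randn(1, 3)   # stands in for the weight of Dense(3 => 1); no bias
x = rand(3)

g = make_grad(W)(x)                          # first-order gradient: vec(W')
H = Zygote.jacobian(make_grad(W), x)[1]      # reverse-over-reverse: 3×3
```

H should come out all zeros, like the Flux result above: sum(W * y) is linear in y, so its gradient does not depend on x at all.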

mcabbott commented 2 years ago

Globals are evil.

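I think zero is correct, as the model is linear in x (strictly, affine), so its gradient does not depend on x and the Jacobian of that gradient vanishes. If you change it to use tanh, then these all agree: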

julia> model
Dense(3 => 1, tanh)  # 4 parameters

julia> rr = rand(3);

julia> ForwardDiff.jacobian(grad_f,rr)
3×3 Matrix{Float64}:
 -0.018781    0.013774   -0.0526393
  0.013774   -0.0101018   0.0386056
 -0.0526393   0.0386056  -0.147537

julia> ForwardDiff.jacobian(grad_f_m(model), rr)
3×3 Matrix{Float64}:
 -0.018781    0.013774   -0.0526393
  0.013774   -0.0101018   0.0386056
 -0.0526393   0.0386056  -0.147537

julia> Zygote.jacobian(grad_f_m(model), rr)[1]
3×3 Matrix{Float64}:
 -0.018781    0.013774   -0.0526393
  0.013774   -0.0101018   0.0386056
 -0.0526393   0.0386056  -0.147537
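To see why zero is the right answer for the identity activation, and what the tanh matrices above should look like, here is a package-free sketch with hand-derived derivatives checked by central finite differences. The helpers grad_affine, grad_tanh, hess_tanh and fd_jacobian are hypothetical, and W, b merely have the shapes of a Dense(3 => 1) layer; this is not how Zygote or ForwardDiff compute anything.

```julia
using LinearAlgebra

# Hypothetical helpers: analytic gradients of x -> sum(act.(W * x .+ b)).
# Identity activation: the gradient is W' * ones, which never looks at x.
grad_affine(W, b, x) = W' * ones(size(W, 1))
# tanh activation: d/dz tanh(z) = 1 - tanh(z)^2.
grad_tanh(W, b, x) = W' * (1 .- tanh.(W * x .+ b) .^ 2)

# Analytic Hessian for tanh: d²/dz² tanh(z) = -2 tanh(z) (1 - tanh(z)^2).
function hess_tanh(W, b, x)
    t = tanh.(W * x .+ b)
    return W' * Diagonal(-2 .* t .* (1 .- t .^ 2)) * W
end

# Central finite differences of a vector-valued function g at x;
# column k approximates the derivative of g with respect to x[k].
function fd_jacobian(g, x; h = 1e-6)
    n = length(x)
    J = zeros(n, n)
    for k in 1:n
        e = zeros(n)
        e[k] = h
        J[:, k] = (g(x + e) - g(x - e)) / (2h)
    end
    return J
end

W, b, x = randn(1, 3), randn(1), rand(3)   # same shapes as Dense(3 => 1)
Hlin  = fd_jacobian(z -> grad_affine(W, b, z), x)   # exactly zero
Htanh = fd_jacobian(z -> grad_tanh(W, b, z), x)     # ≈ hess_tanh(W, b, x)
```

With a single output, hess_tanh is W' * d * W for a scalar d, i.e. a symmetric rank-one matrix, which is consistent with the 3×3 matrices shown above (each row is a multiple of the first).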
YichengDWu commented 2 years ago

I expect it to behave exactly like a pure function:

julia> grad_f(x) = gradient(x -> sum(x.^3),x)[1]
grad_f (generic function with 1 method)

julia> Zygote.jacobian(grad_f,rand(3))[1]
3×3 Matrix{Float64}:
 0.0456902  0.0       0.0
 0.0        0.449621  0.0
 0.0        0.0       4.86983
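For comparison, Zygote also provides two Hessian helpers that package this nesting: Zygote.hessian (forward-over-reverse, ForwardDiff over Zygote) and Zygote.hessian_reverse (reverse-over-reverse, the same nesting as taking the jacobian of grad_f). A minimal sketch checking both against the analytic Hessian of sum(x.^3), which is Diagonal(6 .* x); cubesum is a hypothetical name, and the sketch assumes a Zygote version that has both functions.

```julia
using Zygote, LinearAlgebra

# Hypothetical name for the function whose gradient grad_f above computes.
cubesum(x) = sum(x .^ 3)

x = rand(3)
H_fr = Zygote.hessian(cubesum, x)          # forward-over-reverse
H_rr = Zygote.hessian_reverse(cubesum, x)  # reverse-over-reverse
```

Zygote's own documentation describes hessian_reverse as usually slower and more likely to run into errors than hessian, so forward-over-reverse is often the safer default for second derivatives.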
YichengDWu commented 2 years ago

We have more than one bug here 🥲, see #1264.

YichengDWu commented 2 years ago

Copying the example from there:

julia> function f(x, bias)
           jac = Zygote.jacobian(x->x.^3, x)[1]
           return jac * x .+ bias
       end
f (generic function with 1 method)

julia> x,bias = rand(3),rand(3)
([0.2279638899624825, 0.6476786632858718, 0.13745627655377346], [0.051516386842686224, 0.6842360463718182, 0.22031281411507742])

julia> Zygote.gradient(b -> sum(f(x,b)), rand(3))
ERROR: Mutating arrays is not supported -- called copyto!(SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/dev/limitations.html#Array-mutation-1

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] _throw_mutation_error(f::Function, args::SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\array.jl:70
  [3] (::Zygote.var"#448#449"{SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}})(#unused#::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\array.jl:85
  [4] (::Zygote.var"#2506#back#450"{Zygote.var"#448#449"{SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}}})(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
  [5] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:183 [inlined]
  [6] (::typeof(∂(_gradcopy!)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [7] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:165 [inlined]
  [8] (::typeof(∂(withjacobian)))(Δ::NamedTuple{(:val, :grad), Tuple{Nothing, Tuple{Matrix{Float64}}}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [9] (::Zygote.var"#216#217"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, typeof(∂(withjacobian))})(Δ::NamedTuple{(:val, :grad), Tuple{Nothing, Tuple{Matrix{Float64}}}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:207
 [10] #1909#back
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [11] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:140 [inlined]
 [12] (::typeof(∂(jacobian)))(Δ::Tuple{Matrix{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [13] Pullback
    @ .\REPL[20]:2 [inlined]
 [14] (::typeof(∂(f)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [15] Pullback
    @ .\REPL[22]:1 [inlined]
 [16] (::typeof(∂(#23)))(Δ::Float64)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [17] (::Zygote.var"#60#61"{typeof(∂(#23))})(Δ::Float64)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41
 [18] gradient(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:76
 [19] top-level scope
    @ REPL[22]:1
 [20] top-level scope
    @ C:\Users\Luffy\.julia\packages\CUDA\DfvRa\src\initialization.jl:52
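As the trace suggests, the failure comes from Zygote.jacobian itself, which fills its result in place (the copyto! into a SubArray via _gradcopy!). Because x -> x.^3 acts elementwise, its Jacobian is simply Diagonal(3 .* x.^2), so for this particular function one can build that matrix directly and avoid the mutation. This is a sketch for this example only, not a general fix; f_nomut is a hypothetical name.

```julia
using Zygote, LinearAlgebra

# Hypothetical mutation-free variant of f: the Jacobian of the elementwise
# map y -> y.^3 is Diagonal(3 .* y.^2), so build it directly instead of
# calling Zygote.jacobian inside the function being differentiated.
f_nomut(x, bias) = Diagonal(3 .* x .^ 2) * x .+ bias

x, bias = rand(3), rand(3)
gb = Zygote.gradient(b -> sum(f_nomut(x, b)), bias)[1]   # expected: ones(3)
gx = Zygote.gradient(y -> sum(f_nomut(y, bias)), x)[1]   # expected: 9 .* x.^2
```

The gradient with respect to bias is all ones because the Jacobian term does not depend on bias; the gradient with respect to x is 9 .* x.^2, since Diagonal(3 .* x.^2) * x equals 3 .* x.^3 elementwise.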
YichengDWu commented 2 years ago

This might be a better example, since Lux layers should behave exactly like pure functions:

using Lux, Zygote, Random

model = Dense(3,1)
ps,st = Lux.setup(Random.default_rng(),model)
grad_f(x) = gradient(x -> sum(model(x,ps,st)[1]),x)[1]

Zygote.jacobian(grad_f,rand(3))
ERROR: Can't differentiate foreigncall expression.
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/dev/limitations.html

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] Pullback
    @ .\iddict.jl:102 [inlined]
  [3] (::typeof(∂(get)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [4] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:68 [inlined]
  [5] (::typeof(∂(accum_global)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [6] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:79 [inlined]
  [7] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [8] Pullback
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [10] getindex
    @ .\tuple.jl:29 [inlined]
 [11] map
    @ .\tuple.jl:222 [inlined]
 [12] unthunk_tangent
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:36 [inlined]
 [13] #1789#back
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [15] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Tuple{Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [17] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:76 [inlined]
 [18] (::typeof(∂(gradient)))(Δ::Tuple{Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [19] Pullback
    @ .\REPL[11]:1 [inlined]
 [20] (::typeof(∂(grad_f)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [21] #216
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:207 [inlined]
 [22] (::Zygote.var"#1909#back#218"{Zygote.var"#216#217"{Tuple{Tuple{Nothing}}, typeof(∂(grad_f))}})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
 [23] Pullback
    @ .\operators.jl:1085 [inlined]
 [24] (::typeof(∂(#_#83)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [25] (::Zygote.var"#216#217"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#_#83))})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:207
 [26] #1909#back
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [27] Pullback
    @ .\operators.jl:1085 [inlined]
 [28] (::typeof(∂(ComposedFunction{typeof(Zygote._jvec), typeof(grad_f)}(Zygote._jvec, grad_f))))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [29] (::Zygote.var"#60#61"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), typeof(grad_f)}(Zygote._jvec, grad_f)))})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41
 [30] withjacobian(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:162
 [31] jacobian(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:140
 [32] top-level scope
    @ REPL[13]:1
 [33] top-level scope
    @ C:\Users\Luffy\.julia\packages\CUDA\DfvRa\src\initialization.jl:52
YichengDWu commented 2 years ago

I tried your approach and it doesn't work with Lux. Why? @ToucheSir

using Lux, Zygote, Random

model = Dense(3,1)
ps,st = Lux.setup(Random.default_rng(),model)
grad_f(m,p,s) = x -> gradient(y -> sum(m(y,p,s)[1]),x)[1]

Zygote.jacobian(grad_f(model,ps,st),rand(3))
ERROR: Can't differentiate foreigncall expression.
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/dev/limitations.html

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] Pullback
    @ .\essentials.jl:599 [inlined]
  [3] (::typeof(∂(getindex)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [4] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\tools\builtins.jl:12 [inlined]
  [5] (::typeof(∂(literal_getindex)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [6] Pullback
    @ .\reflection.jl:752 [inlined]
  [7] (::typeof(∂(fieldcount)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [8] Pullback
    @ C:\Users\Luffy\.julia\packages\ChainRulesCore\ctmSK\src\tangent_types\tangent.jl:220 [inlined]
  [9] (::typeof(∂(canonicalize)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [10] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\chainrules.jl:115 [inlined]
 [11] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\chainrules.jl:183 [inlined]
 [12] (::typeof(∂(_project)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [13] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:235 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [15] Pullback
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [17] Pullback
    @ C:\Users\Luffy\.julia\packages\Lux\lEqCI\src\layers\basic.jl:631 [inlined]
 [18] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Vector{Float64}, Nothing, Nothing})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [19] Pullback
    @ .\REPL[4]:1 [inlined]
 [20] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [21] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41 [inlined]
 [22] (::typeof(∂(λ)))(Δ::Tuple{Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [23] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:76 [inlined]
 [24] (::typeof(∂(gradient)))(Δ::Tuple{Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [25] Pullback
    @ .\REPL[4]:1 [inlined]
 [26] (::typeof(∂(λ)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [27] #216
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:207 [inlined]
 [28] (::Zygote.var"#1909#back#218"{Zygote.var"#216#217"{Tuple{Tuple{Nothing}}, typeof(∂(λ))}})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
 [29] Pullback
    @ .\operators.jl:1085 [inlined]
 [30] (::typeof(∂(#_#83)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [31] (::Zygote.var"#216#217"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#_#83))})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:207
 [32] #1909#back
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [33] Pullback
    @ .\operators.jl:1085 [inlined]
 [34] (::typeof(∂(λ)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [35] (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41
 [36] withjacobian(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:162
 [37] jacobian(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:140
 [38] top-level scope
    @ REPL[5]:1
 [39] top-level scope
    @ C:\Users\Luffy\.julia\packages\CUDA\DfvRa\src\initialization.jl:52
ToucheSir commented 2 years ago

Why would it? It's not like Lux has any special magic here. It's probably calling much the same code under the hood.

YichengDWu commented 2 years ago

> It's probably calling much the same code under the hood.

But one works and one doesn't? Maybe I don't understand how Zygote differentiates a functor.

ToucheSir commented 2 years ago

I don't think it has anything to do with functors, but rather that Lux has getindex calls in the forward pass where Flux does not. That said, the failing calls don't look like the ones in Lux itself but like ones from within Zygote's internals: given that https://github.com/FluxML/Zygote.jl/blob/cb59b6c780635a24afdc39c06f4de92ce4f52a0e/src/lib/lib.jl#L207 shows up in the stack trace, perhaps some splatting is happening in Lux layers?