LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

Differentiating `Zygote.pullback` #621

Closed: prbzrg closed this issue 2 months ago

prbzrg commented 2 months ago

Error:

ERROR: LoadError: Mutating arrays is not supported -- called setindex!(Vector{Float32}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\array.jl:70
  [3] (::Zygote.var"#539#540"{Vector{Float32}})(::Nothing)
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\array.jl:82
  [4] (::Zygote.var"#2623#back#541"{Zygote.var"#539#540"{Vector{Float32}}})(Δ::Nothing)
    @ Zygote C:\Users\prbzr\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
  [5] _mapreducedim!
    @ .\reducedim.jl:317 [inlined]
  [6] (::Zygote.Pullback{Tuple{typeof(Base._mapreducedim!), typeof(identity), typeof(Base.add_sum), Vector{Float32}, Matrix{Float32}}, Any})(Δ::Nothing)
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
  [7] mapreducedim!
    @ .\reducedim.jl:324 [inlined]
  [8] #sum!#852
    @ .\reducedim.jl:1034 [inlined]
  [9] (::Zygote.Pullback{Tuple{Base.var"##sum!#852", Bool, typeof(sum!), typeof(identity), Vector{Float32}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{…}, Tuple{…}}, Zygote.Pullback{Tuple{…}, Any}}})(Δ::Nothing)
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [10] sum!
    @ .\reducedim.jl:1034 [inlined]
 [11] (::Zygote.Pullback{Tuple{typeof(Core.kwcall), @NamedTuple{init::Bool}, typeof(sum!), typeof(identity), Vector{Float32}, Matrix{Float32}}, Any})(Δ::Nothing)
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [12] #sum!#853
    @ .\reducedim.jl:1036 [inlined]
 [13] (::Zygote.Pullback{Tuple{Base.var"##sum!#853", Bool, typeof(sum!), Vector{…}, Matrix{…}}, Tuple{Zygote.Pullback{…}, Zygote.Pullback{…}, Zygote.var"#2013#back#204"{…}}})(Δ::Nothing)
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [14] sum!
    @ .\reducedim.jl:1036 [inlined]
 [15] __added_bias_gradient
    @ C:\Users\prbzr\.julia\packages\LuxLib\VDD3J\src\utils.jl:184 [inlined]
 [16] __matmul_bias_partials
    @ C:\Users\prbzr\.julia\packages\LuxLib\VDD3J\src\impl\fused_dense.jl:78 [inlined]
 [17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, FillArrays.Fill{…}, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [18] #46
    @ C:\Users\prbzr\.julia\packages\LuxLib\VDD3J\src\impl\fused_dense.jl:47 [inlined]
 [19] (::Zygote.Pullback{Tuple{LuxLib.var"#46#49"{…}, Matrix{…}}, Any})(Δ::Tuple{Nothing, Nothing, Nothing, FillArrays.Fill{Float32, 2, Tuple{…}}, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [20] ZBack
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:211 [inlined]
 [21] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, FillArrays.Fill{…}, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [22] Pullback
    @ C:\Users\prbzr\.julia\packages\LuxLib\VDD3J\src\api\dense.jl:46 [inlined]
 [23] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, Nothing, FillArrays.Fill{…}, Nothing, Nothing, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [24] Pullback
    @ C:\Users\prbzr\.julia\packages\LuxLib\VDD3J\src\api\dense.jl:38 [inlined]
 [25] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, FillArrays.Fill{…}, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [26] Pullback
    @ C:\Users\prbzr\.julia\packages\Lux\ErEns\src\layers\basic.jl:218 [inlined]
 [27] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, FillArrays.Fill{…}, Nothing, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [28] Pullback
    @ C:\Users\prbzr\.julia\packages\LuxCore\8lRV2\src\LuxCore.jl:180 [inlined]
 [29] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, FillArrays.Fill{…}, Nothing, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [30] Pullback
    @ C:\Users\prbzr\.julia\packages\Lux\ErEns\src\helpers\stateful.jl:83 [inlined]
 [31] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, FillArrays.Fill{…}, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [32] Pullback
    @ D:\Codes\Mine\bug-report\br-3\br-3-3.jl:10 [inlined]
 [33] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, FillArrays.Fill{…}})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [34] #75
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:91 [inlined]
 [35] (::Zygote.Pullback{Tuple{Zygote.var"#75#76"{…}, Matrix{…}}, Tuple{Zygote.Pullback{…}, Zygote.var"#2180#back#303"{…}, Zygote.Pullback{…}}})(Δ::Tuple{FillArrays.Fill{Float32, 2, Tuple{…}}})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [36] fn1
    @ D:\Codes\Mine\bug-report\br-3\br-3-3.jl:11 [inlined]
 [37] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [38] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:91
 [39] gradient(::Function, ::Matrix{Float32}, ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2))), bias = ViewAxis(5:6, ShapedAxis((2, 1))))}}})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:148
 [40] top-level scope
    @ D:\Codes\Mine\bug-report\br-3\br-3-3.jl:15
 [41] include(fname::String)
    @ Base.MainInclude .\client.jl:489
 [42] top-level scope
    @ REPL[1]:1
in expression starting at D:\Codes\Mine\bug-report\br-3\br-3-3.jl:15
Some type information was truncated. Use `show(err)` to see complete types.

MRE:

using ComponentArrays, Lux, Random, Zygote, ForwardDiff

nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)
snn = StatefulLuxLayer(nn, ps, st)

function fn1(u, p)
    z, back = Zygote.pullback(x -> snn(x, p), u)
    sum(only(back(z))) + sum(z)
end

fn1(r, ps)
Zygote.gradient(fn1, r, ps)

Environment:

Status `D:\Codes\Mine\bug-report\br-3\Project.toml`
  [b0b7db55] ComponentArrays v0.15.11
  [f6369f11] ForwardDiff v0.10.36
  [b2108857] Lux v0.5.41
  [e88e6eb3] Zygote v0.6.69
  [9a3f8284] Random

prbzrg commented 2 months ago

This works with Lux v0.5.37. How can I get it working with Lux v0.5.41? Is this a bug, or did I miss something?

prbzrg commented 2 months ago

Maybe related to https://github.com/LuxDL/Lux.jl/issues/286

avik-pal commented 2 months ago

Pullbacks are hard to differentiate directly; see https://github.com/LuxDL/Lux.jl/issues/610#issuecomment-2085160267. We just need some rrules for `DifferentiationInterface.pullback`. Differentiating `Zygote.pullback` is almost never going to work unless someone finds a clever trick to write the tangent for the pullback function. `DI.pullback`, on the other hand, is quite simple; DEQs.jl already does that.
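To illustrate why an explicit VJP is easier to nest than a `Zygote.pullback` closure: if the VJP is written as ordinary non-mutating array operations, Zygote can differentiate it like any other function. The sketch below is plain Zygote with no Lux; it assumes a hand-written stand-in for `Dense(2, 2, tanh)`, and the names `dense`, `dense_vjp`, and `loss` are hypothetical. It cross-checks the reverse-mode gradient against ForwardDiff. It only illustrates the idea and is not how Lux or LuxLib implement their nested-AD rules.

```julia
using Zygote, ForwardDiff

# Hypothetical stand-in for Dense(2, 2, tanh): y = tanh.(W * x .+ b)
dense(W, b, x) = tanh.(W * x .+ b)

# Hand-written VJP of `dense` w.r.t. x with cotangent v, using only
# non-mutating operations. Since tanh'(a) = 1 - tanh(a)^2:
#   v' * J = W' * ((1 .- y .^ 2) .* v)
function dense_vjp(W, b, x, v)
    y = dense(W, b, x)
    return W' * ((1 .- y .^ 2) .* v)
end

# Analogue of the thread's fn1: VJP (cotangent = input) plus the forward value
loss(W, b, x) = sum(dense_vjp(W, b, x, x)) + sum(dense(W, b, x))

W = randn(Float32, 2, 2)
b = randn(Float32, 2)
x = rand(Float32, 2, 4)

# Reverse-mode gradient w.r.t. the parameters; no Zygote.pullback is nested here
gW, gb = Zygote.gradient((W, b) -> loss(W, b, x), W, b)

# Forward-mode cross-check of the same gradients
gW_fd = ForwardDiff.gradient(W -> loss(W, b, x), W)
gb_fd = ForwardDiff.gradient(b -> loss(W, b, x), b)

println(isapprox(gW, gW_fd; rtol = 1e-3), " ", isapprox(gb, gb_fd; rtol = 1e-3))
```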

prbzrg commented 2 months ago

I'm still getting the error with

(br-3) pkg> st
Status `D:\Codes\Mine\bug-report\br-3\Project.toml`
  [b0b7db55] ComponentArrays v0.15.11
  [a0c0ee7d] DifferentiationInterface v0.3.3
  [f6369f11] ForwardDiff v0.10.36
  [b2108857] Lux v0.5.42
  [e88e6eb3] Zygote v0.6.69
  [9a3f8284] Random

(br-3) pkg> st --outdated
Status `D:\Codes\Mine\bug-report\br-3\Project.toml`

(br-3) pkg> st --outdated -m
Status `D:\Codes\Mine\bug-report\br-3\Manifest.toml`

It still fails even after switching to DI:

using ComponentArrays, Lux, Random, Zygote, ForwardDiff, DifferentiationInterface

nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)
snn = StatefulLuxLayer(nn, ps, st)

function fn1(u, p)
    z, uJ = DifferentiationInterface.value_and_pullback(x -> snn(x, p), AutoZygote(), u, u)
    sum(uJ) + sum(z)
end

fn1(r, ps)
DifferentiationInterface.gradient(x -> fn1(r, x), AutoZygote(), ps)

avik-pal commented 2 months ago

Use `Lux.vector_jacobian_product` instead.

prbzrg commented 2 months ago

I tried:

using ComponentArrays, Lux, Random, Zygote, ForwardDiff, DifferentiationInterface

nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)
snn = StatefulLuxLayer(nn, ps, st)

function fn1(u, p)
    z, uJ = Lux.vector_jacobian_product(x -> snn(x, p), AutoZygote(), u, u)
    sum(uJ) + sum(z)
end

fn1(r, ps)
# DifferentiationInterface.gradient(x -> fn1(r, x), AutoZygote(), ps)
Zygote.gradient(fn1, r, ps)

But the problem is still there. 😬

avik-pal commented 2 months ago

Closures don't work; see the first part of https://lux.csail.mit.edu/stable/manual/nested_autodiff. Also, `Lux.vector_jacobian_product` (https://lux.csail.mit.edu/stable/api/Lux/utilities#Lux.vector_jacobian_product) returns only the VJP, not the value and the VJP (returning both could be added later, but it doesn't change the code much).
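To make the value-versus-VJP distinction concrete, here is a small sketch in plain Zygote (no Lux). It assumes a hand-written Dense-like layer `dense(W, b, x) = tanh.(W * x .+ b)`; `dense` and `dense_vjp` are hypothetical names. `Zygote.pullback` returns both the value and a pullback closure, whereas a VJP-only routine returns just `u' * J`, so the forward value has to be computed with a separate call, as the final code below does with `z = snn(u)`.

```julia
using Zygote

# Hypothetical Dense-like layer and its hand-written VJP w.r.t. x
dense(W, b, x) = tanh.(W * x .+ b)
dense_vjp(W, b, x, v) = W' * ((1 .- dense(W, b, x) .^ 2) .* v)

W = randn(Float32, 2, 2)
b = randn(Float32, 2)
u = rand(Float32, 2, 4)

# First-order use of Zygote.pullback: value plus pullback closure.
# back(u) returns the cotangents for (W, b, x); keep only the x part.
z, back = Zygote.pullback(dense, W, b, u)
_, _, vjp_zygote = back(u)   # cotangent = u, as in the thread's final code

# A VJP-only routine returns just the array; the value is a separate call
vjp_manual = dense_vjp(W, b, u, u)
z_separate = dense(W, b, u)

println(isapprox(vjp_zygote, vjp_manual; rtol = 1e-4), " ", z ≈ z_separate)
```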

prbzrg commented 2 months ago

Thanks, the problem is resolved. Final code:

using ComponentArrays, Lux, Random, Zygote, ForwardDiff, DifferentiationInterface

nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)

function fn1(u, ps, st)
    snn = StatefulLuxLayer(nn, ps, st)
    z = snn(u)
    uJ = Lux.vector_jacobian_product(snn, AutoZygote(), u, u)
    sum(uJ) + sum(z)
end

fn1(r, ps, st)
# DifferentiationInterface.gradient(x -> fn1(r, x, st), AutoZygote(), ps)
Zygote.gradient(fn1, r, ps, st)

avik-pal commented 2 months ago

Capturing DI calls would have been the ideal solution, but it causes method ambiguities, and I would have to manually define the functions for every possible combination, which would get messy: https://github.com/LuxDL/Lux.jl/issues/600#issuecomment-2094566210