prbzrg closed this issue 2 months ago.
Downgrading to Lux v0.5.37 worked.
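For anyone who needs the same workaround, here is a minimal sketch of pinning Lux at v0.5.37 in the active environment with Pkg. It assumes that version can be resolved alongside the other packages in the environment, as it was for the reporter:

```julia
using Pkg

# Install exactly Lux v0.5.37 into the active environment
Pkg.add(name = "Lux", version = "0.5.37")

# Pin it so that a later `Pkg.update()` does not move Lux past v0.5.37
Pkg.pin("Lux")
```

Adding `Lux = "=0.5.37"` under `[compat]` in Project.toml achieves the same thing declaratively.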
Error:
ERROR: LoadError: MethodError: no method matching fast_materialize(::Static.False, ::Static.False, ::Matrix{Float32})
Closest candidates are:
fast_materialize(::SB, ::DB, ::Base.Broadcast.Broadcasted{S}) where {S, SB, DB}
@ FastBroadcast C:\Users\prbzr\.julia\packages\FastBroadcast\ux5mz\src\FastBroadcast.jl:22
Stacktrace:
[1] macro expansion
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0 [inlined]
[2] _pullback(::Zygote.Context{false}, ::typeof(FastBroadcast.fast_materialize), ::Static.False, ::Static.False, ::Matrix{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:81
[3] __activation_gradient
@ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\utils.jl:187 [inlined]
[4] #44
@ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\impl\fused_dense.jl:45 [inlined]
[5] _pullback(ctx::Zygote.Context{false}, f::LuxLib.var"#44#47"{typeof(tanh_fast), typeof(identity), Matrix{…}, Base.ReshapedArray{…}, Matrix{…}, SubArray{…}}, args::Matrix{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[6] ZBack
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:211 [inlined]
[7] Pullback
@ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\api\dense.jl:46 [inlined]
[8] Pullback
@ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\api\dense.jl:38 [inlined]
[9] Pullback
@ C:\Users\prbzr\.julia\packages\Lux\ANzxX\src\layers\basic.jl:218 [inlined]
[10] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Tuple{Matrix{…}, Nothing})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[11] Pullback
@ D:\Codes\Mine\bug-report\br-3\br-3.jl:9 [inlined]
[12] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Matrix{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[13] #291
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
[14] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#291#292"{Tuple{Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Matrix{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[15] #2169#back
@ C:\Users\prbzr\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
[16] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}}, args::Matrix{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[17] Pullback
@ .\operators.jl:1045 [inlined]
[18] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{var"#1#2"}, Tuple{ComponentVector{Float32, Vector{…}, Tuple{…}}}, @Kwargs{}}, Any}, args::Matrix{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[19] Pullback
@ .\operators.jl:1044 [inlined]
[20] Pullback
@ .\operators.jl:1041 [inlined]
[21] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{Base.var"##_#103", @Kwargs{}, ComposedFunction{…}, ComponentVector{…}}, Tuple{Zygote.Pullback{…}, Zygote.Pullback{…}}}, args::Vector{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[22] #291
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
[23] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#291#292"{Tuple{Tuple{…}, Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Vector{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[24] #2169#back
@ C:\Users\prbzr\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
[25] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{Tuple{…}, Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}}, args::Vector{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[26] Pullback
@ .\operators.jl:1041 [inlined]
[27] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Vector{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[28] #75
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:91 [inlined]
[29] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Vector{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[30] withjacobian
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\grad.jl:150 [inlined]
[31] _pullback(::Zygote.Context{false}, ::typeof(withjacobian), ::var"#1#2", ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[32] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:838
[33] adjoint
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:203 [inlined]
[34] _pullback
@ C:\Users\prbzr\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:67 [inlined]
[35] jacobian
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\grad.jl:128 [inlined]
[36] _pullback(::Zygote.Context{false}, ::typeof(jacobian), ::var"#1#2", ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[37] fn1
@ D:\Codes\Mine\bug-report\br-3\br-3.jl:9 [inlined]
[38] _pullback(ctx::Zygote.Context{false}, f::typeof(fn1), args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2))), bias = ViewAxis(5:6, ShapedAxis((2, 1))))}}})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[39] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2))), bias = ViewAxis(5:6, ShapedAxis((2, 1))))}}})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:90
[40] pullback
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:88 [inlined]
[41] gradient(f::Function, args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2))), bias = ViewAxis(5:6, ShapedAxis((2, 1))))}}})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:147
[42] top-level scope
@ D:\Codes\Mine\bug-report\br-3\br-3.jl:13
[43] include(fname::String)
@ Base.MainInclude .\client.jl:489
[44] top-level scope
@ REPL[1]:1
in expression starting at D:\Codes\Mine\bug-report\br-3\br-3.jl:13
Some type information was truncated. Use `show(err)` to see complete types.
MRE:
```julia
using ComponentArrays, Lux, Random, Zygote

nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)

function fn1(z)
    sum(first(Zygote.jacobian(x -> first(nn(r, x, st)), z)))
end

fn1(ps)
Zygote.gradient(fn1, ps)
```
Environment:
```
Status `D:\Codes\Mine\bug-report\br-3\Project.toml`
[b0b7db55] ComponentArrays v0.15.11
[b2108857] Lux v0.5.40
[e88e6eb3] Zygote v0.6.69
[9a3f8284] Random
```
Yeah that is some weird Zygote broadcast handling quirk.
From v0.5.40 onwards we use completely different backend operations, which are faster and allocate significantly less, but this comes at the cost of nested reverse-over-reverse Zygote AD (which, to be fair, worked only in very limited cases and was never documented, for good reason).
The manual page https://lux.csail.mit.edu/stable/manual/nested_autodiff covers nested AD with respect to the inputs, but the same for the parameters hasn't been implemented yet.
I ran into the same issue when taking a Zygote gradient of a pullback:
```julia
using Lux, Zygote, ComponentArrays, Random

X = collect(range(0, 1, length = 10)) |> permutedims
Y = zeros(axes(X))
nn = Chain(Dense(1 => 20, tanh), Dense(20 => 1))
ps, st = Lux.setup(Xoshiro(0), nn)
pv = ComponentArray(ps)

function loss(p)
    u(x) = 1 .+ nn(x, p, st)[1] .* x
    ux(x) = Zygote.pullback(u, x)[2](ones(size(x)))[1]
    pred = ux(X)
    loss = sum(abs2, pred)
    return loss
end

Zygote.gradient(loss, pv)
```
The code has not worked since v0.5.40.
I added the following to LocalPreferences.toml, but it didn't change anything:
```toml
[Lux]
DisableAutomaticNestedADSwitching = true
```
Also, I didn't use StatefulLuxLayer, so I don't think the "Nested Automatic Differentiation" feature gets activated.
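The same LocalPreferences.toml entry can also be written from the REPL with Preferences.jl. A minimal sketch, assuming Preferences.jl is installed in the active environment alongside Lux; preferences are typically read when a package is loaded or precompiled, so a fresh Julia session is assumed to be needed before the setting has any effect:

```julia
using Lux, Preferences

# Writes `DisableAutomaticNestedADSwitching = true` under the [Lux] table of
# LocalPreferences.toml next to the active Project.toml.
# `force = true` overwrites an existing value instead of throwing an error.
set_preferences!(Lux, "DisableAutomaticNestedADSwitching" => true; force = true)

# Read the stored value back (returns `nothing` if it was never set)
load_preference(Lux, "DisableAutomaticNestedADSwitching")
```

Per this thread, setting it does not change this particular error.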
> I met the same issue when taking Zygote gradient of pullback
This is a separate issue. See ~https://github.com/LuxDL/Lux.jl/issues/544~ https://github.com/LuxDL/Lux.jl/issues/600. FWIW, pullback gradients are used extensively in DeepEquilibriumNetworks.jl, see https://github.com/SciML/DeepEquilibriumNetworks.jl/blob/main/ext/DeepEquilibriumNetworksZygoteExt.jl. We just need to add the overload for DI.pullback.
The core problem is still the same: Zygote handles Broadcast.broadcasted operations in an unusual way that doesn't allow using FastBroadcast and similar packages. (This part has nothing to do with the nested AD rules that were introduced; it comes from https://github.com/LuxDL/Lux.jl/pull/591.) To fix this, we just need to introduce an rrule for the parameter Jacobian on Base.Fix1(::StatefulLuxLayer, x), similar to how the other jacobian and gradient calls are captured.
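The general mechanism of capturing a call with an rrule can be illustrated with ChainRulesCore. This is a toy sketch, not Lux's actual fix: the function `mytanh` is hypothetical and merely stands in for a kernel that Zygote should not trace into (such as a FastBroadcast-based one). Once the rrule is defined, Zygote uses the hand-written pullback instead of differentiating the function body.

```julia
using ChainRulesCore, Zygote

# Hypothetical stand-in for an operation Zygote should not trace into
mytanh(x) = tanh.(x)

# Custom reverse rule: Zygote calls this instead of tracing `mytanh`
function ChainRulesCore.rrule(::typeof(mytanh), x)
    y = mytanh(x)
    function mytanh_pullback(ȳ)
        # d/dx tanh(x) = 1 - tanh(x)^2, reusing the forward result `y`
        return NoTangent(), unthunk(ȳ) .* (1 .- y .^ 2)
    end
    return y, mytanh_pullback
end

x = [0.0, 0.5]
g = only(Zygote.gradient(v -> sum(mytanh(v)), x))
```

Capturing the call at this level keeps Zygote away from whatever the function does internally, which is the property the proposed Jacobian rrule relies on.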
I'm getting this error with code that was working a month ago:
I will update this!