LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

Nested AD for Parameter Gradient/Jacobian #610

Closed: prbzrg closed this issue 2 months ago

prbzrg commented 2 months ago

I'm getting this error for code that was working a month ago:

ERROR: LoadError: MethodError: no method matching fast_materialize(::Static.False, ::Static.False, ::Matrix{Float32})

Closest candidates are:
  fast_materialize(::SB, ::DB, ::Base.Broadcast.Broadcasted{S}) where {S, SB, DB}
   @ FastBroadcast C:\Users\prbzr\.julia\packages\FastBroadcast\ux5mz\src\FastBroadcast.jl:22

Stacktrace:
   [1] macro expansion
     @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0 [inlined]
   [2] _pullback(::Zygote.Context{false}, ::typeof(FastBroadcast.fast_materialize), ::Static.False, ::Static.False, ::Matrix{Float32})
     @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:81
   [3] __activation_gradient
     @ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\utils.jl:187 [inlined]
   [4] #44
     @ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\impl\fused_dense.jl:45 [inlined]
   [5] _pullback(ctx::Zygote.Context{false}, f::LuxLib.var"#44#47"{typeof(tanh_fast), typeof(identity), Matrix{Float32}, Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Matrix{Float32}, Nothing}, args::Matrix{Float32})
     @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
   [6] ZBack
     @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:211 [inlined]
   [7] Pullback
     @ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\api\dense.jl:46 [inlined]
   [8] Pullback
     @ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\api\dense.jl:31 [inlined]
   [9] Pullback
     @ C:\Users\prbzr\.julia\packages\Lux\ANzxX\src\layers\basic.jl:218 [inlined]
  [10] Pullback
     @ C:\Users\prbzr\.julia\packages\LuxCore\8lRV2\src\LuxCore.jl:180 [inlined]

Versions:
  [b2108857] Lux v0.5.40
  [82251201] LuxLib v0.3.18
  [bb33d45b] LuxCore v0.1.14

I will update this!

prbzrg commented 2 months ago

Downgrading to Lux v0.5.37 worked.

prbzrg commented 2 months ago

Error:

ERROR: LoadError: MethodError: no method matching fast_materialize(::Static.False, ::Static.False, ::Matrix{Float32})

Closest candidates are:
  fast_materialize(::SB, ::DB, ::Base.Broadcast.Broadcasted{S}) where {S, SB, DB}
   @ FastBroadcast C:\Users\prbzr\.julia\packages\FastBroadcast\ux5mz\src\FastBroadcast.jl:22

Stacktrace:
  [1] macro expansion
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0 [inlined]
  [2] _pullback(::Zygote.Context{false}, ::typeof(FastBroadcast.fast_materialize), ::Static.False, ::Static.False, ::Matrix{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:81
  [3] __activation_gradient
    @ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\utils.jl:187 [inlined]
  [4] #44
    @ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\impl\fused_dense.jl:45 [inlined]
  [5] _pullback(ctx::Zygote.Context{false}, f::LuxLib.var"#44#47"{typeof(tanh_fast), typeof(identity), Matrix{…}, Base.ReshapedArray{…}, Matrix{…}, SubArray{…}}, args::Matrix{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
  [6] ZBack
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:211 [inlined]
  [7] Pullback
    @ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\api\dense.jl:46 [inlined]
  [8] Pullback
    @ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\api\dense.jl:38 [inlined]
  [9] Pullback
    @ C:\Users\prbzr\.julia\packages\Lux\ANzxX\src\layers\basic.jl:218 [inlined]
 [10] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Tuple{Matrix{…}, Nothing})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [11] Pullback
    @ D:\Codes\Mine\bug-report\br-3\br-3.jl:9 [inlined]
 [12] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Matrix{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [13] #291
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
 [14] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#291#292"{Tuple{Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Matrix{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [15] #2169#back
    @ C:\Users\prbzr\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [16] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}}, args::Matrix{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [17] Pullback
    @ .\operators.jl:1045 [inlined]
 [18] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{var"#1#2"}, Tuple{ComponentVector{Float32, Vector{…}, Tuple{…}}}, @Kwargs{}}, Any}, args::Matrix{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [19] Pullback
    @ .\operators.jl:1044 [inlined]
 [20] Pullback
    @ .\operators.jl:1041 [inlined]
 [21] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{Base.var"##_#103", @Kwargs{}, ComposedFunction{…}, ComponentVector{…}}, Tuple{Zygote.Pullback{…}, Zygote.Pullback{…}}}, args::Vector{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [22] #291
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
 [23] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#291#292"{Tuple{Tuple{…}, Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Vector{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [24] #2169#back
    @ C:\Users\prbzr\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [25] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{Tuple{…}, Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}}, args::Vector{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [26] Pullback
    @ .\operators.jl:1041 [inlined]
 [27] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Vector{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [28] #75
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:91 [inlined]
 [29] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Vector{Float32})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [30] withjacobian
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\grad.jl:150 [inlined]
 [31] _pullback(::Zygote.Context{false}, ::typeof(withjacobian), ::var"#1#2", ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [32] _apply(::Function, ::Vararg{Any})
    @ Core .\boot.jl:838
 [33] adjoint
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:203 [inlined]
 [34] _pullback
    @ C:\Users\prbzr\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:67 [inlined]
 [35] jacobian
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\grad.jl:128 [inlined]
 [36] _pullback(::Zygote.Context{false}, ::typeof(jacobian), ::var"#1#2", ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [37] fn1
    @ D:\Codes\Mine\bug-report\br-3\br-3.jl:9 [inlined]
 [38] _pullback(ctx::Zygote.Context{false}, f::typeof(fn1), args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2))), bias = ViewAxis(5:6, ShapedAxis((2, 1))))}}})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [39] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2))), bias = ViewAxis(5:6, ShapedAxis((2, 1))))}}})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:90
 [40] pullback
    @ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:88 [inlined]
 [41] gradient(f::Function, args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2))), bias = ViewAxis(5:6, ShapedAxis((2, 1))))}}})
    @ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:147
 [42] top-level scope
    @ D:\Codes\Mine\bug-report\br-3\br-3.jl:13
 [43] include(fname::String)
    @ Base.MainInclude .\client.jl:489
 [44] top-level scope
    @ REPL[1]:1
in expression starting at D:\Codes\Mine\bug-report\br-3\br-3.jl:13
Some type information was truncated. Use `show(err)` to see complete types.

MRE:

using ComponentArrays, Lux, Random, Zygote

nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)                        # fixed input batch
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)                        # flat parameter vector

function fn1(z)                                # sum of the Jacobian w.r.t. the parameters
    sum(first(Zygote.jacobian(x -> first(nn(r, x, st)), z)))
end

fn1(ps)                                        # works
Zygote.gradient(fn1, ps)                       # errors on Lux v0.5.40

Environment:

Status `D:\Codes\Mine\bug-report\br-3\Project.toml`
  [b0b7db55] ComponentArrays v0.15.11
  [b2108857] Lux v0.5.40
  [e88e6eb3] Zygote v0.6.69
  [9a3f8284] Random

avik-pal commented 2 months ago

Yeah, that is a weird quirk in how Zygote handles broadcasting.

From v0.5.40 onward, we use completely different backend operations. They are faster and allocate significantly less, but they come at the cost of nested reverse-over-reverse Zygote AD (which, to be fair, worked only in very limited cases and was never documented, for good reason).

The approach in https://lux.csail.mit.edu/stable/manual/nested_autodiff handles nested AD with respect to the inputs, but the equivalent for the parameters hasn't been implemented yet.
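Until parameter nested AD is supported, one possible workaround for the quantity in the MRE is to avoid Zygote reverse-over-reverse entirely and nest ForwardDiff instead. This is a swapped-in technique, not the Lux nested-AD mechanism linked above, and it has not been checked against Lux itself. It assumes a hand-written stand-in, `dense_fwd` (hypothetical), for `Dense(2, 2, tanh)`, with the flat parameters laid out like the ComponentVector axes in the stack trace (weight in 1:4, bias in 5:6). `jac_sum` is likewise a hypothetical name.

```julia
using ForwardDiff, Random

# Hypothetical stand-in for Dense(2, 2, tanh); parameter layout [vec(W); b]
# mirrors the ComponentVector axes in the trace but is an assumption here.
function dense_fwd(p, x)
    W = reshape(p[1:4], 2, 2)   # 2×2 weight matrix
    b = p[5:6]                  # bias, broadcast across the batch columns
    return tanh.(W * x .+ b)
end

x = rand(Xoshiro(0), 2, 4)      # input batch (2 features × 4 samples)
p0 = randn(Xoshiro(1), 6)       # flattened parameters

# Same quantity as fn1 in the MRE: sum of the 8×6 Jacobian of the output w.r.t. p
jac_sum(p) = sum(ForwardDiff.jacobian(q -> dense_fwd(q, x), p))

# Forward-over-forward: ForwardDiff's tags keep the two derivative levels apart,
# so no reverse-mode pullback is ever differentiated.
g = ForwardDiff.gradient(jac_sum, p0)
```

Forward mode costs grow with the number of parameters, so this trade-off only makes sense for small models like the one in the MRE.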

vavrines commented 2 months ago

I met the same issue when taking a Zygote gradient of a pullback:

using Lux, Zygote, ComponentArrays, Random

X = collect(range(0, 1, length = 10)) |> permutedims
Y = zeros(axes(X))

nn = Chain(Dense(1 => 20, tanh), Dense(20 => 1))
ps, st = Lux.setup(Xoshiro(0), nn)
pv = ComponentArray(ps)

function loss(p)
    u(x) = 1 .+ nn(x, p, st)[1] .* x
    ux(x) = Zygote.pullback(u, x)[2](ones(size(x)))[1]

    pred = ux(X)
    loss = sum(abs2, pred)

    return loss
end

Zygote.gradient(loss, pv)

This code has not worked since v0.5.40.

prbzrg commented 2 months ago

I added

[Lux]
DisableAutomaticNestedADSwitching = true

to LocalPreferences.toml, but it didn't change anything.

prbzrg commented 2 months ago

Also, I didn't use StatefulLuxLayer, so I don't think the "Nested Automatic Differentiation" feature gets activated.

avik-pal commented 2 months ago

> I met the same issue when taking a Zygote gradient of a pullback

This is a separate issue. See https://github.com/LuxDL/Lux.jl/issues/600. FWIW, gradients of pullbacks are used extensively in DeepEquilibriumModels; see https://github.com/SciML/DeepEquilibriumNetworks.jl/blob/main/ext/DeepEquilibriumNetworksZygoteExt.jl. We just need to add the overload for DI.pullback.

The core problem is still the same: Zygote handles Broadcast.broadcasted operations in a strange way that prevents using FastBroadcast and similar packages. (This part has nothing to do with the nested AD rules that were introduced; it comes from https://github.com/LuxDL/Lux.jl/pull/591.) To fix this, we just need to introduce an rrule for the parameter Jacobian on Base.Fix1(::StatefulLuxLayer, x), similar to how the other Jacobian and gradient calls are captured.
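The general idea of capturing a Jacobian call behind an rrule can be illustrated without Lux. If Zygote finds a custom `ChainRulesCore.rrule` for the function that computes the Jacobian, it calls that pullback instead of differentiating through its own pullback code, which is where the FastBroadcast failure occurs. This is a minimal sketch under stated assumptions, not the Base.Fix1(::StatefulLuxLayer, x) rule described above. `dense_fwd` and `param_jac_sum` are hypothetical. The pullback uses nested ForwardDiff rather than whatever Lux would actually use, and `x` is treated as a constant.

```julia
using ChainRulesCore, ForwardDiff, Random, Zygote

# Hypothetical stand-in for a dense tanh layer; parameters flattened as [vec(W); b]
dense_fwd(p, x) = tanh.(reshape(p[1:4], 2, 2) * x .+ p[5:6])

# Hypothetical helper: sum of the parameter Jacobian of dense_fwd at (p, x)
param_jac_sum(p, x) = sum(ForwardDiff.jacobian(q -> dense_fwd(q, x), p))

# Custom rrule: Zygote uses this pullback instead of tracing through the
# Jacobian computation (so no reverse-over-reverse is attempted).
function ChainRulesCore.rrule(::typeof(param_jac_sum), p, x)
    y = param_jac_sum(p, x)
    function param_jac_sum_pullback(ȳ)
        # Gradient w.r.t. p via nested ForwardDiff; x is treated as a constant
        # in this sketch, so no tangent is returned for it.
        dp = unthunk(ȳ) .* ForwardDiff.gradient(q -> param_jac_sum(q, x), p)
        return NoTangent(), dp, NoTangent()
    end
    return y, param_jac_sum_pullback
end

x = rand(Xoshiro(0), 2, 4)
p0 = randn(Xoshiro(1), 6)
g = only(Zygote.gradient(p -> param_jac_sum(p, x), p0))
```

The design choice is that the outer AD never sees the inner Jacobian's implementation: only the rule's author decides how second-order information is computed.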