LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

[3rd Order AD] Pullback over twice Jacobian #614

Open aksuhton opened 5 months ago

aksuhton commented 5 months ago

I'm looking to take two Jacobians of a neural network (with respect to its inputs) and then do parameter-based optimization. Up front, I want to thank you for taking the time to read this; I was hesitant to post, given that I probably ought to look somewhere other than nested AD for this purpose.
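To make the target explicit, here is my reading of what the MWE below computes. $f_\theta$ stands for `potential` with parameters $\theta$ (`ps`), and $i$ runs over every entry of the flattened input $x$. The model returns the diagonal second derivatives of the network, and training differentiates a loss of those with respect to $\theta$:

```latex
g_i(x;\theta) = \frac{\partial f_{\theta,i}(x)}{\partial x_i}, \qquad
h_i(x;\theta) = \frac{\partial g_i(x;\theta)}{\partial x_i} = \frac{\partial^2 f_{\theta,i}(x)}{\partial x_i^2}, \qquad
L(\theta) = \sum_i h_i(x;\theta)^2,
\qquad
\frac{\partial L}{\partial \theta} = \sum_i 2\, h_i(x;\theta)\, \frac{\partial^3 f_{\theta,i}(x)}{\partial \theta\, \partial x_i^2}.
```

So $\nabla_\theta L$ needs mixed third derivatives of $f_\theta$ (two with respect to $x$, one with respect to $\theta$), hence the "3rd Order AD" in the title.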

In my actual code, the call to `pullback` goes through, but I get a scalar indexing error when calling `back`. I have yet to reproduce this with an MWE, so I apologize if this is more of a "request for help" than an "issue".
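For context on the forward/backward distinction above: `Zygote.pullback` runs the forward pass and returns the value together with a `back` closure. An error raised "on back" only surfaces when that closure is called with the output sensitivity. Below is a minimal sketch on a toy function; `f`, `x`, and `w` are hypothetical and unrelated to the Lux model.

```julia
using Zygote

# Toy loss, unrelated to the Lux model in the MWE
f(x, w) = sum(abs2, w .* x)

x = [1.0, 2.0]
w = [3.0, 4.0]

# Forward pass: errors raised here happen while "the pullback goes through"
y, back = Zygote.pullback(f, x, w)

# Reverse pass: seed with the output sensitivity; errors raised here happen "on back"
∂x, ∂w = back(1.0)

@show y ∂x ∂w  # y = 73.0, ∂x = [18.0, 64.0], ∂w = [6.0, 32.0]
```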

With the MWE below, I instead get an undefined reference error (`UndefRefError`) on the gradient:

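using Zygote, Random, Lux, LinearAlgebra

# The model returns the diagonal of the Jacobian of the diagonal of the
# input-Jacobian of `potential`, reshaped to the shape of x.
model = @compact(; potential=Dense(5 => 5, gelu)) do x
    # First derivative: diag of ∂potential/∂x (x is treated as 15 flattened entries)
    function jac_pot(x)
        return reshape(diag(only(Zygote.jacobian(potential, x))), size(x))
    end
    # Second derivative: diag of ∂jac_pot/∂x
    return reshape(diag(only(Zygote.jacobian(jac_pot, x))), size(x))
end

ps, st = Lux.setup(Random.default_rng(), model);
x = randn(Float32, 5, 3);
m_x, st_ = model(x, ps, st);  # forward pass

# Gradient of the loss w.r.t. x and ps (third-order AD); this is the call that fails
∂x, ∂ps, _ = Zygote.gradient(Base.Fix1(sum, abs2) ∘ first ∘ model, x, ps, st)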

Stacktrace[1:37]:

1-element ExceptionStack:
Compiling Tuple{Zygote.Pullback{Tuple{typeof(LuxLib.__generic_dense_bias_activation), typeof(gelu), Matrix{Float32}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{LuxZygoteExt.var"#16#20"{LuxZygoteExt.var"#15#19"{Int64, StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Float32}, Float32, 1}}, Vector{Float32}}, Any}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{LuxZygoteExt.var"#16#20"{LuxZygoteExt.var"#15#19"{Int64, StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Float32}, Float32, 1}}}: UndefRefError: access to undefined reference
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{typeof(LuxLib.__generic_dense_bias_activation), typeof(gelu), Matrix{Float32}, Matrix{ForwardDiff.Dual{…}}, Vector{Float32}}, Any}, args::Matrix{ForwardDiff.Dual{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:81
  [3] Pullback
    @ ~/.julia/packages/LuxLib/VDD3J/src/api/dense.jl:52 [inlined]
  [4] Pullback
    @ ~/.julia/packages/LuxLib/VDD3J/src/api/dense.jl:38 [inlined]
  [5] Pullback
    @ ~/.julia/packages/Lux/ANzxX/src/layers/basic.jl:218 [inlined]
  [6] Pullback
    @ ~/.julia/packages/LuxCore/8lRV2/src/LuxCore.jl:180 [inlined]
  [7] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{typeof(LuxCore.apply), Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, Matrix{ForwardDiff.Dual{…}}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, Tuple{…}}, args::Tuple{Matrix{ForwardDiff.Dual{…}}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/Lux/ANzxX/src/helpers/stateful.jl:82 [inlined]
  [9] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, Matrix{ForwardDiff.Dual{…}}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, Tuple{…}}, args::Matrix{ForwardDiff.Dual{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/Lux/ANzxX/ext/LuxZygoteExt.jl:102 [inlined]
 [11] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{LuxZygoteExt.var"#15#19"{Int64, StatefulLuxLayer{…}}, Matrix{ForwardDiff.Dual{…}}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, Tuple{…}}, args::ForwardDiff.Dual{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [12] #75
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91 [inlined]
 [13] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#75#76"{Zygote.Pullback{…}}, args::ForwardDiff.Dual{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [14] gradient
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148 [inlined]
 [15] _pullback(::Zygote.Context{false}, ::typeof(gradient), ::LuxZygoteExt.var"#15#19"{Int64, StatefulLuxLayer{…}}, ::Matrix{ForwardDiff.Dual{…}}, ::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [16] #16
    @ ~/.julia/packages/Lux/ANzxX/ext/LuxZygoteExt.jl:104 [inlined]
 [17] _pullback(::Zygote.Context{false}, ::LuxZygoteExt.var"#16#20"{LuxZygoteExt.var"#15#19"{Int64, StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, ::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{LuxZygoteExt.var"#16#20"{LuxZygoteExt.var"#15#19"{Int64, StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Float32}, Float32, 1}}, ::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [18] __forwarddiff_jvp
    @ ~/.julia/packages/Lux/ANzxX/ext/LuxForwardDiffExt.jl:32 [inlined]
 [19] _pullback(::Zygote.Context{false}, ::typeof(Lux.__forwarddiff_jvp), ::LuxZygoteExt.var"#16#20"{LuxZygoteExt.var"#15#19"{Int64, StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, ::Matrix{Float32}, ::Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Base.ReshapedArray{Float32, 2, Diagonal{Float32, Vector{Float32}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{}}, ::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [20] #14
    @ ~/.julia/packages/Lux/ANzxX/ext/LuxZygoteExt.jl:103 [inlined]
 [21] _pullback(ctx::Zygote.Context{false}, f::LuxZygoteExt.var"#14#18"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, Matrix{Float32}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, args::Tuple{Int64, SubArray{Float32, 1, Base.ReshapedArray{Float32, 2, Diagonal{Float32, Vector{Float32}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, false}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [22] MappingRF
    @ ./reduce.jl:100 [inlined]
 [23] _foldl_impl
    @ ./reduce.jl:58 [inlined]
 [24] _pullback(::Zygote.Context{false}, ::typeof(Base._foldl_impl), ::Base.MappingRF{LuxZygoteExt.var"#14#18"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, Matrix{Float32}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, Base.BottomRF{typeof(Lux.__internal_add)}}, ::Base._InitialValue, ::Base.Iterators.Enumerate{Vector{SubArray{Float32, 1, Base.ReshapedArray{Float32, 2, Diagonal{Float32, Vector{Float32}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, false}}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [25] foldl_impl
    @ ./reduce.jl:48 [inlined]
 [26] _pullback(::Zygote.Context{false}, ::typeof(Base.foldl_impl), ::Base.MappingRF{LuxZygoteExt.var"#14#18"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, Matrix{Float32}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, Base.BottomRF{typeof(Lux.__internal_add)}}, ::Base._InitialValue, ::Base.Iterators.Enumerate{Vector{SubArray{Float32, 1, Base.ReshapedArray{Float32, 2, Diagonal{Float32, Vector{Float32}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, false}}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [27] mapfoldl_impl
    @ ./reduce.jl:44 [inlined]
 [28] _pullback(::Zygote.Context{false}, ::typeof(Base.mapfoldl_impl), ::LuxZygoteExt.var"#14#18"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, Matrix{Float32}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, ::typeof(Lux.__internal_add), ::Base._InitialValue, ::Base.Iterators.Enumerate{Vector{SubArray{Float32, 1, Base.ReshapedArray{Float32, 2, Diagonal{Float32, Vector{Float32}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, false}}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [29] #mapfoldl#298
    @ ./reduce.jl:175 [inlined]
 [30] _pullback(::Zygote.Context{false}, ::Base.var"##mapfoldl#298", ::Base._InitialValue, ::typeof(mapfoldl), ::LuxZygoteExt.var"#14#18"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, Matrix{Float32}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, ::typeof(Lux.__internal_add), ::Base.Iterators.Enumerate{Vector{SubArray{Float32, 1, Base.ReshapedArray{Float32, 2, Diagonal{Float32, Vector{Float32}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, false}}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [31] mapfoldl
    @ ./reduce.jl:175 [inlined]
 [32] _pullback(::Zygote.Context{false}, ::typeof(mapfoldl), ::LuxZygoteExt.var"#14#18"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, Matrix{Float32}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, ::typeof(Lux.__internal_add), ::Base.Iterators.Enumerate{Vector{SubArray{Float32, 1, Base.ReshapedArray{Float32, 2, Diagonal{Float32, Vector{Float32}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, false}}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [33] mapreduce
    @ ./reduce.jl:307 [inlined]
 [34] #13
    @ ~/.julia/packages/Lux/ANzxX/ext/LuxZygoteExt.jl:101 [inlined]
 [35] _pullback(ctx::Zygote.Context{false}, f::LuxZygoteExt.var"#13#17"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, Matrix{Float32}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, Tuple{Matrix{Float32}}}, args::ChainRulesCore.Tangent{Any, Tuple{Diagonal{Float32, Vector{Float32}}}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [36] ZBack
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:211 [inlined]
 [37] Pullback
    @ ~/.julia/packages/Lux/ANzxX/ext/LuxZygoteExt.jl:81 [inlined]

aksuhton commented 5 months ago

Stacktrace[38:end]:

 [38] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{typeof(jacobian), StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, Matrix{Float32}}, Tuple{Zygote.ZBack{LuxZygoteExt.var"#13#17"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, Matrix{Float32}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, Tuple{Matrix{Float32}}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:ps, Zygote.Context{false}, StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}}}}, args::Tuple{Diagonal{Float32, Vector{Float32}}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [39] Pullback
    @ ./none:0 [inlined]
 [40] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}, Matrix{Float32}}, Tuple{Zygote.ZBack{ChainRules.var"#size_pullback#917"}, Zygote.var"#2763#back#609"{Zygote.var"#603#607"{Vector{Float32}, Tuple{Tuple{Int64, Int64}}}}, Zygote.Pullback{Tuple{typeof(jacobian), StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, Matrix{Float32}}, Tuple{Zygote.ZBack{LuxZygoteExt.var"#13#17"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, Matrix{Float32}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, Tuple{Matrix{Float32}}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:ps, Zygote.Context{false}, StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:potential, Zygote.Context{false}, var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}, StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Zygote.ZBack{ChainRules.var"#diag_pullback#2059"}, Zygote.Pullback{Tuple{typeof(only), Tuple{Matrix{Float32}}}, Tuple{Zygote.var"#2029#back#213"{Zygote.var"#back#211"{1, 1, Zygote.Context{false}, Matrix{Float32}}}}}}}, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [41] #291
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:206 [inlined]
 [42] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#291#292"{Tuple{Tuple{Nothing}}, Zygote.Pullback{Tuple{var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}, Matrix{Float32}}, Tuple{Zygote.ZBack{ChainRules.var"#size_pullback#917"}, Zygote.var"#2763#back#609"{Zygote.var"#603#607"{Vector{Float32}, Tuple{Tuple{Int64, Int64}}}}, Zygote.Pullback{Tuple{typeof(jacobian), StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, Matrix{Float32}}, Tuple{Zygote.ZBack{LuxZygoteExt.var"#13#17"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, Matrix{Float32}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, Tuple{Matrix{Float32}}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:ps, Zygote.Context{false}, StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:potential, Zygote.Context{false}, var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}, StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Zygote.ZBack{ChainRules.var"#diag_pullback#2059"}, Zygote.Pullback{Tuple{typeof(only), Tuple{Matrix{Float32}}}, Tuple{Zygote.var"#2029#back#213"{Zygote.var"#back#211"{1, 1, Zygote.Context{false}, Matrix{Float32}}}}}}}}, 
args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [43] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [44] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}}, Zygote.Pullback{Tuple{var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}, Matrix{Float32}}, Tuple{Zygote.ZBack{ChainRules.var"#size_pullback#917"}, Zygote.var"#2763#back#609"{Zygote.var"#603#607"{Vector{Float32}, Tuple{Tuple{Int64, Int64}}}}, Zygote.Pullback{Tuple{typeof(jacobian), StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, Matrix{Float32}}, Tuple{Zygote.ZBack{LuxZygoteExt.var"#13#17"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, Matrix{Float32}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, Tuple{Matrix{Float32}}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:ps, Zygote.Context{false}, StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:potential, Zygote.Context{false}, var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}, StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Zygote.ZBack{ChainRules.var"#diag_pullback#2059"}, Zygote.Pullback{Tuple{typeof(only), Tuple{Matrix{Float32}}}, Tuple{Zygote.var"#2029#back#213"{Zygote.var"#back#211"{1, 1, Zygote.Context{false}, 
Matrix{Float32}}}}}}}}}, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [45] Pullback
    @ ./operators.jl:1045 [inlined]
 [46] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{Matrix{Float32}}, @Kwargs{}}, Any}, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [47] Pullback
    @ ./operators.jl:1044 [inlined]
 [48] Pullback
    @ ./operators.jl:1041 [inlined]
 [49] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{Base.var"##_#103", @Kwargs{}, ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{Matrix{Float32}}, @Kwargs{}}, Tuple{Zygote.Pullback{Tuple{typeof(Zygote._jvec), Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(vec), Matrix{Float32}}, Tuple{Zygote.ZBack{ChainRules.var"#length_pullback#749"}, Zygote.var"#2763#back#609"{Zygote.var"#603#607"{Matrix{Float32}, Tuple{Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{Matrix{Float32}}, @Kwargs{}}, Any}, Zygote.var"#2141#back#281"{Zygote.var"#277#280"}, Zygote.var"#2029#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, typeof(Zygote._jvec)}}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}}, Tuple{Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2013#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, 
@NamedTuple{}}}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{}}, Zygote.var"#2013#back#204"{typeof(identity)}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:outer, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, typeof(Zygote._jvec)}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:inner, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(Zygote._jvec)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(Zygote._jvec)}, Tuple{}}, Zygote.var"#2013#back#204"{typeof(identity)}}}}}}}, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [50] #291
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:206 [inlined]
 [51] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{Base.var"##_#103", @Kwargs{}, ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{Matrix{Float32}}, @Kwargs{}}, Tuple{Zygote.Pullback{Tuple{typeof(Zygote._jvec), Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(vec), Matrix{Float32}}, Tuple{Zygote.ZBack{ChainRules.var"#length_pullback#749"}, Zygote.var"#2763#back#609"{Zygote.var"#603#607"{Matrix{Float32}, Tuple{Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{Matrix{Float32}}, @Kwargs{}}, Any}, Zygote.var"#2141#back#281"{Zygote.var"#277#280"}, Zygote.var"#2029#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, typeof(Zygote._jvec)}}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}}, Tuple{Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2013#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), 
typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{}}, Zygote.var"#2013#back#204"{typeof(identity)}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:outer, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, typeof(Zygote._jvec)}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:inner, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(Zygote._jvec)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(Zygote._jvec)}, Tuple{}}, Zygote.var"#2013#back#204"{typeof(identity)}}}}}}}}, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [52] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [53] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{Base.var"##_#103", @Kwargs{}, ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{Matrix{Float32}}, @Kwargs{}}, Tuple{Zygote.Pullback{Tuple{typeof(Zygote._jvec), Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(vec), Matrix{Float32}}, Tuple{Zygote.ZBack{ChainRules.var"#length_pullback#749"}, Zygote.var"#2763#back#609"{Zygote.var"#603#607"{Matrix{Float32}, Tuple{Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{Matrix{Float32}}, @Kwargs{}}, Any}, Zygote.var"#2141#back#281"{Zygote.var"#277#280"}, Zygote.var"#2029#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, typeof(Zygote._jvec)}}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}}, Tuple{Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2013#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), 
typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{}}, Zygote.var"#2013#back#204"{typeof(identity)}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:outer, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, typeof(Zygote._jvec)}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:inner, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(Zygote._jvec)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(Zygote._jvec)}, Tuple{}}, Zygote.var"#2013#back#204"{typeof(identity)}}}}}}}}}, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [54] Pullback
    @ ./operators.jl:1041 [inlined]
 [55] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Matrix{Float32}}, Tuple{Zygote.var"#2366#back#419"{Zygote.var"#pairs_namedtuple_pullback#418"{(), @NamedTuple{}}}, Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{Base.var"##_#103", @Kwargs{}, ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{Matrix{Float32}}, @Kwargs{}}, Tuple{Zygote.Pullback{Tuple{typeof(Zygote._jvec), Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(vec), Matrix{Float32}}, Tuple{Zygote.ZBack{ChainRules.var"#length_pullback#749"}, Zygote.var"#2763#back#609"{Zygote.var"#603#607"{Matrix{Float32}, Tuple{Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{Matrix{Float32}}, @Kwargs{}}, Any}, Zygote.var"#2141#back#281"{Zygote.var"#277#280"}, Zygote.var"#2029#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, typeof(Zygote._jvec)}}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, 
Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}}, Tuple{Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2013#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{}}, Zygote.var"#2013#back#204"{typeof(identity)}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:outer, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, typeof(Zygote._jvec)}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:inner, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(Zygote._jvec)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(Zygote._jvec)}, Tuple{}}, Zygote.var"#2013#back#204"{typeof(identity)}}}}}}}}}, Zygote.var"#2013#back#204"{typeof(identity)}}}, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [56] #75
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91 [inlined]
 [57] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#75#76"{Zygote.Pullback{Tuple{ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Matrix{Float32}}, Tuple{Zygote.var"#2366#back#419"{Zygote.var"#pairs_namedtuple_pullback#418"{(), @NamedTuple{}}}, Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{Base.var"##_#103", @Kwargs{}, ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{Matrix{Float32}}, @Kwargs{}}, Tuple{Zygote.Pullback{Tuple{typeof(Zygote._jvec), Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(vec), Matrix{Float32}}, Tuple{Zygote.ZBack{ChainRules.var"#length_pullback#749"}, Zygote.var"#2763#back#609"{Zygote.var"#603#607"{Matrix{Float32}, Tuple{Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{Matrix{Float32}}, @Kwargs{}}, Any}, Zygote.var"#2141#back#281"{Zygote.var"#277#280"}, Zygote.var"#2029#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, typeof(Zygote._jvec)}}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), ComposedFunction{typeof(Zygote._jvec), 
var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}}, Tuple{Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2013#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, Tuple{}}, Zygote.var"#2013#back#204"{typeof(identity)}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:outer, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, typeof(Zygote._jvec)}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:inner, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}, var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(Zygote._jvec)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(Zygote._jvec)}, Tuple{}}, Zygote.var"#2013#back#204"{typeof(identity)}}}}}}}}}, Zygote.var"#2013#back#204"{typeof(identity)}}}}, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [58] withjacobian
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/grad.jl:150 [inlined]
 [59] _pullback(::Zygote.Context{false}, ::typeof(withjacobian), ::var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}, ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [60] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [61] adjoint
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
 [62] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [63] jacobian
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/grad.jl:128 [inlined]
 [64] _pullback(::Zygote.Context{false}, ::typeof(jacobian), ::var"#jac_pot#45"{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, @NamedTuple{}}}, ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [65] #43
    @ ./none:0 [inlined]
 [66] _pullback(::Zygote.Context{false}, ::var"#43#44", ::@NamedTuple{potential::Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}}, ::Matrix{Float32}, ::@NamedTuple{potential::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, ::@NamedTuple{potential::@NamedTuple{}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [67] CompactLuxLayer
    @ ~/.julia/packages/Lux/ANzxX/src/helpers/compact.jl:422 [inlined]
 [68] _pullback(::Zygote.Context{false}, ::CompactLuxLayer{nothing, var"#43#44", @NamedTuple{potential::Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}}, @NamedTuple{potential::Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}}, Lux.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}, ::Matrix{Float32}, ::@NamedTuple{potential::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, ::@NamedTuple{potential::@NamedTuple{}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [69] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [70] adjoint
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
 [71] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [72] call_composed
    @ ./operators.jl:1045 [inlined]
 [73] call_composed (repeats 2 times)
    @ ./operators.jl:1044 [inlined]
 [74] #_#103
    @ ./operators.jl:1041 [inlined]
 [75] _pullback(::Zygote.Context{false}, ::Base.var"##_#103", ::@Kwargs{}, ::ComposedFunction{ComposedFunction{Base.Fix1{typeof(sum), typeof(abs2)}, typeof(first)}, CompactLuxLayer{nothing, var"#43#44", @NamedTuple{potential::Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}}, @NamedTuple{potential::Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}}, Lux.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}}, ::Matrix{Float32}, ::@NamedTuple{potential::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, ::@NamedTuple{potential::@NamedTuple{}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [76] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [77] adjoint
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
 [78] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [79] ComposedFunction
    @ ./operators.jl:1041 [inlined]
 [80] _pullback(::Zygote.Context{false}, ::ComposedFunction{ComposedFunction{Base.Fix1{typeof(sum), typeof(abs2)}, typeof(first)}, CompactLuxLayer{nothing, var"#43#44", @NamedTuple{potential::Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}}, @NamedTuple{potential::Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}}, Lux.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}}, ::Matrix{Float32}, ::@NamedTuple{potential::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, ::@NamedTuple{potential::@NamedTuple{}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [81] pullback(::Function, ::Zygote.Context{false}, ::Matrix{Float32}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:90
 [82] pullback(::Function, ::Matrix{Float32}, ::@NamedTuple{potential::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:88
 [83] gradient(::Function, ::Matrix{Float32}, ::@NamedTuple{potential::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:147
avik-pal commented 5 months ago

:sweat: This is third-order AD. To be remotely efficient it needs either forward mode over an HVP, or reverse mode over Taylor mode (TaylorDiff.jl).

First, I would rewrite the function to use some sort of custom Hessian. You only need the diagonal, so Jacobian over Jacobian is a bad idea performance-wise. I think there was an AAAI paper that showed how to compute the diagonal of second-order derivatives very fast, but I can't seem to locate it right now.
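
The MWE computes $\partial^2 f_i / \partial x_i^2$ for a map $f:\mathbb{R}^n \to \mathbb{R}^n$. One way to get only this diagonal is to take, for each coordinate $i$, a second directional derivative along the unit vector $e_i$ with nested forward mode, instead of materializing two full Jacobians. This is a minimal sketch assuming ForwardDiff.jl. `second_directional` and `diag_second_derivative` are hypothetical helper names. The technique is plain forward-over-forward differentiation, not a reconstruction of the AAAI paper mentioned above. It also does not address the outer gradient with respect to the network parameters.

```julia
using ForwardDiff

# d²/ds² f(x + s*v) at s = 0, computed with nested forward-mode derivatives.
function second_directional(f, x::AbstractVector, v::AbstractVector)
    first_deriv(s) = ForwardDiff.derivative(r -> f(x .+ r .* v), s)
    return ForwardDiff.derivative(first_deriv, zero(eltype(x)))
end

# Hypothetical helper: returns (∂²f_i/∂x_i²) for i = 1:n.
# Each coordinate costs one nested forward pass over f. Only component i
# of the directional result is kept.
function diag_second_derivative(f, x::AbstractVector)
    n = length(x)
    return map(1:n) do i
        e = zeros(eltype(x), n)
        e[i] = one(eltype(x))
        second_directional(f, x, e)[i]
    end
end

# Usage: f_i(x) = x_i^3 + (terms linear in x), so ∂²f_i/∂x_i² = 6 x_i.
f(x) = x .^ 3 .+ sum(x)
x = [1.0, 2.0, 3.0]
d2 = diag_second_derivative(f, x)  # [6.0, 12.0, 18.0]
```

The loop costs $n$ nested forward passes, and nothing of size $n \times n$ is ever stored.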

That said, it would still require a custom rrule that uses ForwardDiff over an HVP. I don't think anyone (including me) is actively working on third-order AD, so all I can do is point you to some code that might be helpful: https://github.com/LuxDL/Lux.jl/blob/afd8555a1df4ddba211a9fb860d5db84b5d91ba4/ext/LuxZygoteExt.jl#L42-L63
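
The HVP building block can be computed as forward mode over reverse mode, using $Hv = \frac{d}{dt}\nabla f(x + t v)\big|_{t=0}$. The sketch below is not the linked Lux code and not the custom rrule itself. It assumes ForwardDiff.jl and Zygote.jl. `hvp` is a hypothetical helper name, and the test function was chosen so the result can be checked by hand.

```julia
using ForwardDiff, Zygote

# Hypothetical helper: Hessian-vector product H(x) * v of a scalar function f,
# computed as forward mode (ForwardDiff) over reverse mode (Zygote):
#   H v = d/dt ∇f(x + t v) at t = 0
function hvp(f, x::AbstractVector, v::AbstractVector)
    grad_along(t) = only(Zygote.gradient(f, x .+ t .* v))
    return ForwardDiff.derivative(grad_along, zero(eltype(x)))
end

# Usage: f(x) = (‖x‖²)², with ∇f = 4‖x‖² x and H = 4‖x‖² I + 8 x xᵀ.
f(x) = sum(abs2, x)^2
x = [1.0, 2.0]
v = [1.0, 0.0]
Hv = hvp(f, x, v)  # 4 * 5 * v + 8 * x * (x ⋅ v) = [28.0, 16.0]
```

The third-order case would still need such a product wrapped in a custom rrule. That part is not sketched here.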

avik-pal commented 5 months ago

Actually, @tansongchen might know how to do this kind of third-order differentiation efficiently.

aksuhton commented 5 months ago

Thank you @avik-pal, sorry for the 😓.

avik-pal commented 5 months ago

No worries. I do think it would be a nice-to-have feature if we can get a general implementation up and running.

The :sweat: was for Zygote's bad error messages. Beyond simple cases, Zygote becomes type-unstable under nested reverse mode. It then throws that undefined reference error. Nobody can parse that error unless they already know what is going wrong, which is pretty much the opposite of what an error message should do.

tansongchen commented 5 months ago

Could you write down the mathematical expression for the derivative you want? When I hear keywords like "diagonal" or "forward over reverse", I'm fairly confident I can reformulate the problem to make use of TaylorDiff 🤔

aksuhton commented 5 months ago

Hey @tansongchen, thanks for joining us! Absolutely. I'm trying to learn a mapping between two time sequences:

$\varepsilon_t \rightarrow \sigma_t$.

Now, thermodynamic arguments suggest that

$\sigma_t = \frac{\partial F_t}{\partial \varepsilon_t}$

Here $F_t$ is some function that depends on the history of $\varepsilon$ up to time $t$. The recent nested AD updates let me seek $F_t$ as a temporal convolutional neural network that takes the time sequence $\varepsilon_t$ as input. Yay! However, this still misses some physical structure.
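
The following toy sketch makes the indexing concrete. It assumes ForwardDiff.jl. `free_energy` is a hypothetical stand-in for the network. It is chosen to be causal, so $F_t$ depends only on $\varepsilon_1, \ldots, \varepsilon_t$. As a result its Jacobian is lower triangular, and $\sigma_t$ is the diagonal entry $\partial F_t / \partial \varepsilon_t$.

```julia
using ForwardDiff, LinearAlgebra

# Hypothetical causal stand-in for the learned potential: F_t depends only on
# ε_1, ..., ε_t, here F_t = (ε_1 + ... + ε_t)^3 / 6.
free_energy(ε) = cumsum(ε) .^ 3 ./ 6

ε = [1.0, 2.0, 3.0]
J = ForwardDiff.jacobian(free_energy, ε)  # J[t, s] = ∂F_t/∂ε_s (lower triangular)
σ = diag(J)                               # σ_t = ∂F_t/∂ε_t = (ε_1 + ... + ε_t)^2 / 2
```

Only the diagonal of `J` is used. The rest of the matrix is work a dedicated implementation would avoid.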

Ideally, I'd seek $F_t$ via

$\partial_t \sigma_t = \frac{\partial^2 F_t}{\partial \varepsilon_t^2}\partial_t \varepsilon_t$

as this makes it manifest that the stress $\sigma$ remains fixed whenever the strain $\varepsilon$ remains fixed.

There are (at least) two difficulties here:

  1. I want to take two derivatives of the network output with respect to its inputs. Then I need a third derivative with respect to the network parameters in order to optimize them.
  2. The neural network output is not a scalar. Rather, I'd like it to output a sequence

$F = (F_1[\varepsilon], F_2[\varepsilon], \ldots, F_T[\varepsilon])$

and then collect the derivatives

$\partial^2 F = (\partial^2_{\varepsilon_1} F_1[\varepsilon], \partial^2_{\varepsilon_2} F_2[\varepsilon], \ldots, \partial^2_{\varepsilon_T} F_T[\varepsilon])$.

Please note my shorthand for "functional" dependence

$F_i[\varepsilon] = F_i(\varepsilon_T, \varepsilon_{T-1}, \ldots, \varepsilon_1)$.
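
Below is a toy sketch of the quantity $\partial^2 F$ defined above. It assumes ForwardDiff.jl and uses a hypothetical causal stand-in `free_energy` in place of the temporal convolutional network. It mirrors the Jacobian-of-diagonal-of-Jacobian structure of the original MWE, so it materializes full $T \times T$ Jacobians only to keep their diagonals. It is meant as a reference for checking a faster implementation, not as one.

```julia
using ForwardDiff, LinearAlgebra

# Hypothetical causal stand-in for the network: F_t = (ε_1 + ... + ε_t)^3 / 6.
free_energy(ε) = cumsum(ε) .^ 3 ./ 6

# (∂F_t/∂ε_t)_t: diagonal of the Jacobian of F.
sigma_of(ε) = diag(ForwardDiff.jacobian(free_energy, ε))

# (∂²F_t/∂ε_t²)_t: diagonal of the Jacobian of sigma_of. This is the same
# Jacobian-over-Jacobian pattern as the MWE, with nested ForwardDiff.
d2F(ε) = diag(ForwardDiff.jacobian(sigma_of, ε))

d2F([1.0, 2.0, 3.0])  # closed form ε_1 + ... + ε_t, i.e. [1.0, 3.0, 6.0]
```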

A similar structure appears in Equation (22a) of the following paper. There, however, they use a set of internal variables $\zeta_t$ to capture the dynamics, which I'd rather handle implicitly through a $t$ subscript on $F$:

https://arxiv.org/abs/2005.12183

avik-pal commented 4 months ago

@tansongchen, do you think we can rewrite this in terms of TaylorDiff?

tansongchen commented 4 months ago

It would be easy to do once I support chunking (I'll surely be doing that in a month or two). The model can then be evaluated in one pass of TaylorDiff, so this will be a standard Zygote-over-TaylorDiff overlay, which is well understood.

avik-pal commented 3 months ago

Now that we have better Enzyme support, trying this out with Enzyme might be worthwhile. It will still be messy right now, but hopefully #738 will make life easier here.