JuliaDiff / SparseDiffTools.jl

Fast jacobian computation through sparsity exploitation and matrix coloring
MIT License
237 stars, 41 forks

Report Weird Behavior #164

QiyaoWei closed this issue 2 years ago

QiyaoWei commented 2 years ago

When I run code like this:

    A = VecJacOperator(some_model, data_input, param; autodiff=false)
    println(A*vec(ε)) #ε is a random vector

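the product fails at this line of SparseDiffTools (https://github.com/JuliaDiff/SparseDiffTools.jl/blob/09bc8b1cc1202c93f2a23f8c8497a8c6e1365bdf/src/differentiation/vecjac_products.jl#L8), inside _numargs, which Zygote tries to differentiate through (the product is evaluated inside a function being differentiated by Zygote). The error is: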

ERROR: LoadError: MethodError: no method matching size(::Base.MethodList)
Closest candidates are:
  size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/qr.jl:558
  size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}, ::Integer) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/qr.jl:557
  size(::Union{LinearAlgebra.Cholesky, LinearAlgebra.CholeskyPivoted}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/cholesky.jl:442
  ...
Stacktrace:
  [1] axes
    @ ./abstractarray.jl:89 [inlined]
  [2] _tryaxes(x::Base.MethodList)
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/lib/array.jl:184
  [3] map
    @ ./tuple.jl:213 [inlined]
  [4] ∇map(cx::Zygote.Context, f::SparseDiffTools.var"#62#63"{DataType, DataType, DataType, DataType, DataType, DataType}, args::Base.MethodList)
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/lib/array.jl:199
  [5] _pullback(cx::Zygote.Context, #unused#::typeof(collect), g::Base.Generator{Base.MethodList, SparseDiffTools.var"#62#63"{DataType, DataType, DataType, DataType, DataType, DataType}})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/lib/array.jl:244
  [6] _pullback
    @ ~/.julia/packages/SparseDiffTools/3pVbY/src/differentiation/vecjac_products.jl:8 [inlined]
  [7] _pullback(ctx::Zygote.Context, f::typeof(SparseDiffTools._numargs), args::DiffEqOperators.var"#432#440"{VecJacOperator{Float32, var"#41#42"{MnistWidthStackedDEQ{false, …}}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Any, Bool}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
  [8] _pullback
    @ ~/.julia/packages/SparseDiffTools/3pVbY/src/differentiation/vecjac_products.jl:32 [inlined]
  [9] _pullback(::Zygote.Context, ::SparseDiffTools.var"##num_vecjac!#64", ::Bool, ::typeof(num_vecjac!), ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::DiffEqOperators.var"#432#440"{VecJacOperator{Float32, var"#41#42"{MnistWidthStackedDEQ{false, …}}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Any, Bool}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [10] _pullback
    @ ~/.julia/packages/SparseDiffTools/3pVbY/src/differentiation/vecjac_products.jl:32 [inlined]
 [11] _pullback(::Zygote.Context, ::SparseDiffTools.var"#num_vecjac!##kw", ::NamedTuple{(:compute_f0,), Tuple{Bool}}, ::typeof(num_vecjac!), ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::DiffEqOperators.var"#432#440"{VecJacOperator{Float32, var"#41#42"{MnistWidthStackedDEQ{false, …}}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Any, Bool}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [12] _pullback
    @ ~/.julia/packages/DiffEqOperators/lqggZ/src/vecjac_operators.jl:175 [inlined]
 [13] _pullback(::Zygote.Context, ::typeof(LinearAlgebra.mul!), ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::VecJacOperator{Float32, var"#41#42"{MnistWidthStackedDEQ{false, …}}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Any, Bool}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [14] _pullback
    @ ~/.julia/packages/DiffEqOperators/lqggZ/src/vecjac_operators.jl:104 [inlined]
 [15] _pullback(::Zygote.Context, ::typeof(*), ::VecJacOperator{Float32, var"#41#42"{MnistWidthStackedDEQ{false, …}}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Any, Bool}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [16] _pullback
    @ /workspace/FastDEQ.jl-main/experiments/mnist_deq.jl:222 [inlined]
 [17] _pullback(::Zygote.Context, ::var"##_#40", ::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::SupervisedLossContainer{typeof(Flux.Losses.logitcrossentropy), Float32}, ::MnistWidthStackedDEQ{false, …}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [18] _pullback
    @ /workspace/FastDEQ.jl-main/experiments/mnist_deq.jl:211 [inlined]
 [19] _pullback(::Zygote.Context, ::SupervisedLossContainer{typeof(Flux.Losses.logitcrossentropy), Float32}, ::MnistWidthStackedDEQ{false, …}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [20] _pullback
    @ /workspace/FastDEQ.jl-main/experiments/mnist_deq.jl:346 [inlined]
 [21] _pullback(::Zygote.Context, ::var"#49#54"{SupervisedLossContainer{typeof(Flux.Losses.logitcrossentropy), Float32}, MnistWidthStackedDEQ{false, …}})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [22] pullback(f::Function, ps::Params)
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface.jl:351
 [23] train(config::Dict{String, Any})
    @ Main /workspace/FastDEQ.jl-main/experiments/mnist_deq.jl:346
 [24] top-level scope
    @ /workspace/FastDEQ.jl-main/experiments/mnist_deq.jl:443
 [25] include(fname::String)
    @ Base.MainInclude ./client.jl:444
 [26] top-level scope
    @ REPL[6]:1
 [27] top-level scope
    @ ~/.julia/packages/CUDA/YpW0k/src/initialization.jl:52

So I rewrote that line (line 8) as a plain loop:

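    # Loop version of line 8. `typ` ... `typ6` and `num_types_in_tuple`
    # come from the surrounding `_numargs` code.
    numparam = 0
    for m in methods(f)
        is_special = m.sig <: typ || m.sig <: typ2 || m.sig <: typ3 ||
                     m.sig <: typ4 || m.sig <: typ5 || m.sig <: typ6
        numparam = max(numparam, is_special ? 0 : num_types_in_tuple(m.sig))
    end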

With that change the error goes away. I'm not sure why, though.
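
A likely explanation, judging from frames [1] through [5] of the trace: Zygote's adjoint for `collect` over a `Base.Generator` goes through `∇map`, which calls `axes` on the iterated object, and `Base.MethodList` has no `size` method. An explicit `for` loop never calls `collect`, so it never reaches that adjoint. As a self-contained illustration of the loop pattern, here is a minimal sketch; `max_method_arity` is a hypothetical helper, and it reads the `nargs` field of each `Method` (which counts the function itself) instead of using SparseDiffTools' `num_types_in_tuple` and its special-signature filters.

```julia
# Hypothetical helper illustrating the loop form used above.
# `Method.nargs` counts the function object itself as an argument,
# so subtract 1 to get the number of positional arguments.
function max_method_arity(f)
    numparam = 0
    for m in methods(f)
        numparam = max(numparam, m.nargs - 1)
    end
    return numparam
end

# Example: a function with one- and two-argument methods.
h(x) = x
h(x, y) = x + y

println(max_method_arity(h))  # prints 2
```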

ChrisRackauckas commented 2 years ago

That function should just be defined as nograd.

QiyaoWei commented 2 years ago

You mean VecJacOperator? But what if the Jacobian is part of the training loss?

ChrisRackauckas commented 2 years ago

No, _numargs(f)

QiyaoWei commented 2 years ago

Gotcha. Wrapping the call in Zygote.ignore() does make this error go away. Out of curiosity, what is the most elegant way to exclude a function call from differentiation (nograd)? Currently I am doing it like this (which works), but I'm not sure whether this code is idiomatic:

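    using Zygote

    # Zygote runs the do-block normally on the forward pass but does not
    # try to differentiate through it.
    function wrap(f)
        Zygote.ignore() do
            _numargs(f)
        end
    end
    temp = wrap(f)
    # do other stuff
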
ChrisRackauckas commented 2 years ago

I think you can just add @nograd on the definition of _numargs. @oxinabox is there a ChainRulesCore version of Zygote.@nograd?

oxinabox commented 2 years ago

ChainRulesCore.@non_differentiable
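
A minimal sketch of that suggestion, assuming ChainRulesCore and Zygote are installed. `arity` is a hypothetical stand-in for `_numargs`; it deliberately keeps the comprehension over `methods(f)`, because once a function is declared non-differentiable, Zygote uses the generated rule and only evaluates the body on the forward pass instead of tracing into it.

```julia
using ChainRulesCore
using Zygote

# Hypothetical stand-in for SparseDiffTools' internal `_numargs`:
# the largest number of positional arguments among the methods of `f`
# (`Method.nargs` includes the function itself, hence the `- 1`).
arity(f) = maximum([m.nargs - 1 for m in methods(f)])

# Declare that `arity` has no derivative; Zygote picks up this rule
# and does not differentiate through the comprehension above.
ChainRulesCore.@non_differentiable arity(::Any)

# Toy function with one- and two-argument methods.
h(x) = x
h(x, y) = x + y

# `arity(h)` acts as the constant 2 inside the loss.
loss(x) = arity(h) * sum(abs2, x)

grad = Zygote.gradient(loss, [1.0, 2.0])[1]
println(grad)  # prints [4.0, 8.0]
```

The same one-line declaration placed next to the definition of `_numargs` would make the `Zygote.ignore()` wrapper at every call site unnecessary.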

ChrisRackauckas commented 2 years ago

Did a PR get made for this?

ChrisRackauckas commented 2 years ago

Turns out, looking closer, the proper fix is https://github.com/JuliaDiff/SparseDiffTools.jl/pull/166