JuliaDiff / SparseDiffTools.jl

Fast Jacobian computation through sparsity exploitation and matrix coloring
MIT License

Add num_types_in_tuple definition: Update vecjac_products.jl #162

Closed by QiyaoWei 2 years ago

QiyaoWei commented 2 years ago

Update vecjac_products.jl

QiyaoWei commented 2 years ago

After adding the definition, I run into the following error:

ERROR: LoadError: MethodError: no method matching size(::Base.MethodList)
Closest candidates are:
  size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/qr.jl:558
  size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}, ::Integer) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/qr.jl:557
  size(::Union{LinearAlgebra.Cholesky, LinearAlgebra.CholeskyPivoted}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/cholesky.jl:442
  ...
Stacktrace:
  [1] axes
    @ ./abstractarray.jl:89 [inlined]
  [2] _tryaxes(x::Base.MethodList)
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/lib/array.jl:184
  [3] map
    @ ./tuple.jl:213 [inlined]
  [4] ∇map(cx::Zygote.Context, f::SparseDiffTools.var"#62#63"{DataType, DataType, DataType, DataType, DataType, DataType}, args::Base.MethodList)
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/lib/array.jl:199
  [5] _pullback(cx::Zygote.Context, #unused#::typeof(collect), g::Base.Generator{Base.MethodList, SparseDiffTools.var"#62#63"{DataType, DataType, DataType, DataType, DataType, DataType}})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/lib/array.jl:244
  [6] _pullback
    @ ~/.julia/packages/SparseDiffTools/3pVbY/src/differentiation/vecjac_products.jl:8 [inlined]
  [7] _pullback(ctx::Zygote.Context, f::typeof(SparseDiffTools._numargs), args::DiffEqOperators.var"#432#440"{VecJacOperator{Float32, var"#41#42"{MnistWidthStackedDEQ{false, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, 
Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, 
typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, Chain{Tuple{BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}}}, Chain{Tuple{BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#37#39", Chain{Tuple{typeof(flatten), Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Any, Bool}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
  [8] _pullback
    @ ~/.julia/packages/SparseDiffTools/3pVbY/src/differentiation/vecjac_products.jl:32 [inlined]
  [9] _pullback(::Zygote.Context, ::SparseDiffTools.var"##num_vecjac!#64", ::Bool, ::typeof(num_vecjac!), ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::DiffEqOperators.var"#432#440"{VecJacOperator{Float32, var"#41#42"{MnistWidthStackedDEQ{false, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, 
LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), 
CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, Chain{Tuple{BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}}}, Chain{Tuple{BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#37#39", Chain{Tuple{typeof(flatten), Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Any, Bool}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [10] _pullback
    @ ~/.julia/packages/SparseDiffTools/3pVbY/src/differentiation/vecjac_products.jl:32 [inlined]
 [11] _pullback(::Zygote.Context, ::SparseDiffTools.var"#num_vecjac!##kw", ::NamedTuple{(:compute_f0,), Tuple{Bool}}, ::typeof(num_vecjac!), ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::DiffEqOperators.var"#432#440"{VecJacOperator{Float32, var"#41#42"{MnistWidthStackedDEQ{false, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, 
true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, 
Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, Chain{Tuple{BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}}}, Chain{Tuple{BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#37#39", Chain{Tuple{typeof(flatten), Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Any, Bool}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [12] _pullback
    @ ~/.julia/packages/DiffEqOperators/lqggZ/src/vecjac_operators.jl:175 [inlined]
 [13] _pullback(::Zygote.Context, ::typeof(LinearAlgebra.mul!), ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::VecJacOperator{Float32, var"#41#42"{MnistWidthStackedDEQ{false, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, 
Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, 
typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, Chain{Tuple{BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}}}, Chain{Tuple{BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#37#39", Chain{Tuple{typeof(flatten), Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Any, Bool}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [14] _pullback
    @ ~/.julia/packages/DiffEqOperators/lqggZ/src/vecjac_operators.jl:104 [inlined]
 [15] _pullback(::Zygote.Context, ::typeof(*), ::VecJacOperator{Float32, var"#41#42"{MnistWidthStackedDEQ{false, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), 
Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, 
GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, Chain{Tuple{BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}}}, Chain{Tuple{BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#37#39", Chain{Tuple{typeof(flatten), Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Any, Bool}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [16] _pullback
    @ /workspace/FastDEQ.jl-main/experiments/mnist_deq.jl:222 [inlined]
 [17] _pullback(::Zygote.Context, ::var"##_#40", ::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::SupervisedLossContainer{typeof(Flux.Losses.logitcrossentropy), Float32}, ::MnistWidthStackedDEQ{false, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, 
LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), 
CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, Chain{Tuple{BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}}}, Chain{Tuple{BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#37#39", Chain{Tuple{typeof(flatten), Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [18] _pullback
    @ /workspace/FastDEQ.jl-main/experiments/mnist_deq.jl:211 [inlined]
 [19] _pullback(::Zygote.Context, ::SupervisedLossContainer{typeof(Flux.Losses.logitcrossentropy), Float32}, ::MnistWidthStackedDEQ{false, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, 
NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, 
CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, Chain{Tuple{BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}}}, Chain{Tuple{BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#37#39", Chain{Tuple{typeof(flatten), Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [20] _pullback
    @ /workspace/FastDEQ.jl-main/experiments/mnist_deq.jl:346 [inlined]
 [21] _pullback(::Zygote.Context, ::var"#49#54"{SupervisedLossContainer{typeof(Flux.Losses.logitcrossentropy), Float32}, MnistWidthStackedDEQ{false, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, 
Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, DeepEquilibriumNetwork{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, DeepEquilibriumNetw 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{ResNetLayer{Conv{2, 2, typeof(relu), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, Conv{2, 2, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Flux.Zeros}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}, GroupNorm{typeof(identity), Nothing, Float32, 
Nothing}, GroupNorm{typeof(identity), Nothing, Float32, Nothing}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, SteadyStateAdjoint{0, true, Val{:central}, ZygoteVJP, LinSolveKrylovJL{typeof(Krylov.gmres), Tuple{}, Float32, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:itmax,), Tuple{Int64}}}}}, Base.Iterators.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:maxiters, :verbose), Tuple{Int64, Bool}}}}, Chain{Tuple{BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}}}, Chain{Tuple{BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, MaxPool{2, 4}}}, BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#37#39", Chain{Tuple{typeof(flatten), Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [22] pullback(f::Function, ps::Params)
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface.jl:351
 [23] train(config::Dict{String, Any})
    @ Main /workspace/FastDEQ.jl-main/experiments/mnist_deq.jl:346
 [24] top-level scope
    @ /workspace/FastDEQ.jl-main/experiments/mnist_deq.jl:443
 [25] include(fname::String)
    @ Base.MainInclude ./client.jl:444
 [26] top-level scope
    @ REPL[6]:1
 [27] top-level scope
    @ ~/.julia/packages/CUDA/YpW0k/src/initialization.jl:52

ChrisRackauckas commented 2 years ago

I don't understand what this PR is doing. @avik-pal? Is this just for research right now?

avik-pal commented 2 years ago

Since we didn't want the SciMLBase dependency, we need to copy a few of these functions over from there.
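
For readers unfamiliar with these helpers, here is a minimal sketch of what an argument-counting utility of this kind might look like. It is an illustration under assumptions, not the code in this PR or in SciMLBase: `numargs_sketch`, `g`, and `h` are hypothetical names, and only `num_types_in_tuple` takes its name from the PR title. The sketch counts the entries in each method signature and subtracts one for the function's own type.

```julia
# Illustrative sketch only; the bodies are assumptions, not the PR's code.

# Number of entries in a method signature such as Tuple{typeof(f), Int, Int}.
num_types_in_tuple(sig) = length(sig.parameters)

# Signatures with free type parameters are wrapped in UnionAll; unwrap first.
num_types_in_tuple(sig::UnionAll) = length(Base.unwrap_unionall(sig).parameters)

# Largest positional-argument count over all methods of `f` (hypothetical name).
# The first signature entry is the function's own type, hence the `- 1`.
function numargs_sketch(f)
    return maximum(num_types_in_tuple(m.sig) - 1 for m in methods(f))
end

# Hypothetical functions used only to exercise the sketch.
g(x) = x
g(x, y) = x + y
h(x::T, y::T, z::T) where {T} = x + y + z

println(numargs_sketch(g))
println(numargs_sketch(h))
```

Note that `methods(f)` returns a `Base.MethodList`, which is the type named in the `MethodError` above: the trace shows Zygote attempting to differentiate through the iteration over that list inside `SparseDiffTools._numargs`.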

codecov-commenter commented 2 years ago

Codecov Report

Merging #162 (bffee4f) into master (22f8089) will decrease coverage by 0.42%. The diff coverage is 0.00%.

@@            Coverage Diff             @@
##           master     #162      +/-   ##
==========================================
- Coverage   79.12%   78.69%   -0.43%     
==========================================
  Files          14       14              
  Lines         733      737       +4     
==========================================
  Hits          580      580              
- Misses        153      157       +4     
| Impacted Files | Coverage Δ |
|---|---|
| src/differentiation/vecjac_products.jl | 0.00% <0.00%> (ø) |
