JuliaDiff / TaylorDiff.jl

Taylor-mode automatic differentiation for higher-order derivatives
https://juliadiff.org/TaylorDiff.jl/
MIT License
73 stars 8 forks source link

Compatibility with ComponentArrays and Zygote #37

Closed YichengDWu closed 1 year ago

YichengDWu commented 1 year ago

Hi, thanks for creating this package. I'm getting the following error; could you share some insights?

julia> using ComponentArrays, TaylorDiff, Zygote

julia> x = rand(2);

julia> ps = (W = rand(16,2),);

julia> mlp(x,ps) = ps.W * x;

julia> gradient(θ->TaylorDiff.derivative(c->sum(mlp(c,θ)), x, [1.0, 2.0], 1), ps)
((W = [1.0 2.0; 1.0 2.0; … ; 1.0 2.0; 1.0 2.0],),)

julia> gradient(θ->TaylorDiff.derivative(c->sum(mlp(c,θ)), x, [1.0, 2.0], 1), ComponentArray(ps))
ERROR: Mutating arrays is not supported -- called setindex!(Vector{TaylorScalar{Float64, 2}}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{TaylorScalar{Float64, 2}})
    @ Zygote ~/.julia/packages/Zygote/HTsWj/src/lib/array.jl:88
  [3] (::Zygote.var"#551#552"{Vector{TaylorScalar{Float64, 2}}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/HTsWj/src/lib/array.jl:100
  [4] (::Zygote.var"#2643#back#553"{Zygote.var"#551#552"{Vector{TaylorScalar{Float64, 2}}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
  [5] Pullback
    @ ~/.julia/juliaup/julia-1.9.0+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:812 [inlined]
  [6] (::Zygote.Pullback{Tuple{typeof(LinearAlgebra.generic_matvecmul!), Vector{TaylorScalar{Float64, 2}}, Char, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}, LinearAlgebra.MulAddMul{true, true, Bool, Bool}}, Any})(Δ::FillArrays.Fill{TaylorScalar{Float64, 2}, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/HTsWj/src/compiler/interface2.jl:0
  [7] Pullback
    @ ~/.julia/juliaup/julia-1.9.0+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:81 [inlined]
  [8] Pullback
    @ ~/.julia/juliaup/julia-1.9.0+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:276 [inlined]
  [9] Pullback
    @ ~/.julia/juliaup/julia-1.9.0+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:56 [inlined]
 [10] (::Zygote.Pullback{Tuple{typeof(*), Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(LinearAlgebra.mul!), Vector{TaylorScalar{Float64, 2}}, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(LinearAlgebra.mul!), Vector{TaylorScalar{Float64, 2}}, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}, Bool, Bool}, Tuple{Zygote.Pullback{Tuple{typeof(LinearAlgebra.generic_matvecmul!), Vector{TaylorScalar{Float64, 2}}, Char, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}, LinearAlgebra.MulAddMul{true, true, Bool, Bool}}, Any}, Zygote.Pullback{Tuple{Type{LinearAlgebra.MulAddMul}, Bool, Bool}, Any}}}}}, Zygote.ZBack{ChainRules.var"#similar_pullback#914"{Tuple{Vector{TaylorScalar{Float64, 2}}, DataType, Base.OneTo{Int64}}}}, Zygote.Pullback{Tuple{typeof(Base.promote_op), typeof(LinearAlgebra.matprod), Type{Float64}, Type{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(Core.Compiler.return_type), typeof(LinearAlgebra.matprod), Type{Tuple{Float64, TaylorScalar{Float64, 2}}}}, Tuple{typeof(Core.Compiler.return_type)}}, Zygote.var"#2001#back#200"{typeof(identity)}, Zygote.var"#2157#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.ZBack{ChainRules.var"#apply_type_pullback#42"{Tuple{DataType, DataType}}}}}}}, Zygote.ZBack{ChainRules.var"#axes_pullback#305"}}})(Δ::FillArrays.Fill{TaylorScalar{Float64, 2}, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/HTsWj/src/compiler/interface2.jl:0
 [11] Pullback
    @ ./REPL[147]:1 [inlined]
 [12] (::Zygote.Pullback{Tuple{typeof(mlp), Vector{TaylorScalar{Float64, 2}}, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, Tuple{Zygote.ZBack{ComponentArrays.var"#getproperty_adjoint#87"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}, Symbol}}, Zygote.Pullback{Tuple{typeof(*), Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(LinearAlgebra.mul!), Vector{TaylorScalar{Float64, 2}}, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(LinearAlgebra.mul!), Vector{TaylorScalar{Float64, 2}}, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}, Bool, Bool}, Tuple{Zygote.Pullback{Tuple{typeof(LinearAlgebra.generic_matvecmul!), Vector{TaylorScalar{Float64, 2}}, Char, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}, LinearAlgebra.MulAddMul{true, true, Bool, Bool}}, Any}, Zygote.Pullback{Tuple{Type{LinearAlgebra.MulAddMul}, Bool, Bool}, Any}}}}}, Zygote.ZBack{ChainRules.var"#similar_pullback#914"{Tuple{Vector{TaylorScalar{Float64, 2}}, DataType, Base.OneTo{Int64}}}}, Zygote.Pullback{Tuple{typeof(Base.promote_op), typeof(LinearAlgebra.matprod), Type{Float64}, Type{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(Core.Compiler.return_type), typeof(LinearAlgebra.matprod), Type{Tuple{Float64, TaylorScalar{Float64, 2}}}}, Tuple{typeof(Core.Compiler.return_type)}}, Zygote.var"#2001#back#200"{typeof(identity)}, Zygote.var"#2157#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{Nothing, 
Nothing}}, Zygote.ZBack{ChainRules.var"#apply_type_pullback#42"{Tuple{DataType, DataType}}}}}}}, Zygote.ZBack{ChainRules.var"#axes_pullback#305"}}}}})(Δ::FillArrays.Fill{TaylorScalar{Float64, 2}, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/HTsWj/src/compiler/interface2.jl:0
 [13] Pullback
    @ ./REPL[149]:1 [inlined]
 [14] (::Zygote.Pullback{Tuple{var"#306#308"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, Vector{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(mlp), Vector{TaylorScalar{Float64, 2}}, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, Tuple{Zygote.ZBack{ComponentArrays.var"#getproperty_adjoint#87"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}, Symbol}}, Zygote.Pullback{Tuple{typeof(*), Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(LinearAlgebra.mul!), Vector{TaylorScalar{Float64, 2}}, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(LinearAlgebra.mul!), Vector{TaylorScalar{Float64, 2}}, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}, Bool, Bool}, Tuple{Zygote.Pullback{Tuple{typeof(LinearAlgebra.generic_matvecmul!), Vector{TaylorScalar{Float64, 2}}, Char, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}, LinearAlgebra.MulAddMul{true, true, Bool, Bool}}, Any}, Zygote.Pullback{Tuple{Type{LinearAlgebra.MulAddMul}, Bool, Bool}, Any}}}}}, Zygote.ZBack{ChainRules.var"#similar_pullback#914"{Tuple{Vector{TaylorScalar{Float64, 2}}, DataType, Base.OneTo{Int64}}}}, Zygote.Pullback{Tuple{typeof(Base.promote_op), typeof(LinearAlgebra.matprod), Type{Float64}, Type{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(Core.Compiler.return_type), typeof(LinearAlgebra.matprod), Type{Tuple{Float64, 
TaylorScalar{Float64, 2}}}}, Tuple{typeof(Core.Compiler.return_type)}}, Zygote.var"#2001#back#200"{typeof(identity)}, Zygote.var"#2157#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.ZBack{ChainRules.var"#apply_type_pullback#42"{Tuple{DataType, DataType}}}}}}}, Zygote.ZBack{ChainRules.var"#axes_pullback#305"}}}}}, Zygote.var"#2168#back#299"{Zygote.var"#back#298"{:θ, Zygote.Context{false}, var"#306#308"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}}, Zygote.var"#3011#back#778"{Zygote.var"#772#776"{Vector{TaylorScalar{Float64, 2}}}}}})(Δ::TaylorScalar{Float64, 2})
    @ Zygote ~/.julia/packages/Zygote/HTsWj/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/TaylorDiff/EoBny/src/derivative.jl:37 [inlined]
 [16] Pullback
    @ ~/.julia/packages/TaylorDiff/EoBny/src/derivative.jl:23 [inlined]
 [17] Pullback
    @ ./REPL[149]:1 [inlined]
 [18] (::Zygote.Pullback{Tuple{var"#305#307", ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, Tuple{Zygote.ZBack{ChainRules.var"#vect_pullback#1369"{2, Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Zygote.var"#1974#back#190"{Zygote.var"#186#189"{Zygote.Context{false}, GlobalRef, Vector{Float64}}}, Zygote.Pullback{Tuple{typeof(TaylorDiff.derivative), var"#306#308"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, Vector{Float64}, Vector{Float64}, Int64}, Tuple{Zygote.Pullback{Tuple{typeof(TaylorDiff.derivative), var"#306#308"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, Vector{Float64}, Vector{Float64}, Val{2}}, Tuple{Zygote.ZBack{TaylorDiff.var"#extract_derivative_pullback#125"{2, Float64, Int64}}, Zygote.var"#2865#back#684"{Zygote.var"#map_back#678"{TaylorDiff.var"#118#119"{Val{2}}, 2, Tuple{Vector{Float64}, Vector{Float64}}, Tuple{Tuple{Base.OneTo{Int64}}, Tuple{Base.OneTo{Int64}}}, Vector{Tuple{TaylorScalar{Float64, 2}, Zygote.Pullback{Tuple{TaylorDiff.var"#118#119"{Val{2}}, Float64, Float64}, Tuple{Zygote.var"#2168#back#299"{Zygote.var"#back#298"{:vN, Zygote.Context{false}, TaylorDiff.var"#118#119"{Val{2}}, Val{2}}}, Zygote.Pullback{Tuple{typeof(TaylorDiff.make_taylor), Float64, Float64, Val{2}}, Tuple{Zygote.Pullback{Tuple{Type{TaylorScalar{Float64, 2}}, Float64, Float64}, Tuple{Zygote.Pullback{Tuple{Type{TaylorScalar}, Tuple{Float64, Float64}}, Tuple{Zygote.ZBack{TaylorDiff.var"#taylor_scalar_pullback#120"}}}, Zygote.var"#2001#back#200"{typeof(identity)}, Zygote.ZBack{Zygote.var"#Real_pullback#328"}, Zygote.ZBack{Zygote.var"#Real_pullback#328"}}}}}}}}}}}, Zygote.var"#2198#back#309"{Zygote.Jnew{TaylorDiff.var"#118#119"{Val{2}}, Nothing, false}}, 
Zygote.Pullback{Tuple{var"#306#308"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, Vector{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(mlp), Vector{TaylorScalar{Float64, 2}}, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, Tuple{Zygote.ZBack{ComponentArrays.var"#getproperty_adjoint#87"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}, Symbol}}, Zygote.Pullback{Tuple{typeof(*), Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(LinearAlgebra.mul!), Vector{TaylorScalar{Float64, 2}}, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(LinearAlgebra.mul!), Vector{TaylorScalar{Float64, 2}}, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}, Bool, Bool}, Tuple{Zygote.Pullback{Tuple{typeof(LinearAlgebra.generic_matvecmul!), Vector{TaylorScalar{Float64, 2}}, Char, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}, LinearAlgebra.MulAddMul{true, true, Bool, Bool}}, Any}, Zygote.Pullback{Tuple{Type{LinearAlgebra.MulAddMul}, Bool, Bool}, Any}}}}}, Zygote.ZBack{ChainRules.var"#similar_pullback#914"{Tuple{Vector{TaylorScalar{Float64, 2}}, DataType, Base.OneTo{Int64}}}}, Zygote.Pullback{Tuple{typeof(Base.promote_op), typeof(LinearAlgebra.matprod), Type{Float64}, Type{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(Core.Compiler.return_type), typeof(LinearAlgebra.matprod), Type{Tuple{Float64, TaylorScalar{Float64, 
2}}}}, Tuple{typeof(Core.Compiler.return_type)}}, Zygote.var"#2001#back#200"{typeof(identity)}, Zygote.var"#2157#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.ZBack{ChainRules.var"#apply_type_pullback#42"{Tuple{DataType, DataType}}}}}}}, Zygote.ZBack{ChainRules.var"#axes_pullback#305"}}}}}, Zygote.var"#2168#back#299"{Zygote.var"#back#298"{:θ, Zygote.Context{false}, var"#306#308"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}}, Zygote.var"#3011#back#778"{Zygote.var"#772#776"{Vector{TaylorScalar{Float64, 2}}}}}}}}, Zygote.var"#1910#back#157"{Zygote.var"#153#156"}, Zygote.ZBack{Zygote.var"#plus_pullback#341"{Tuple{Int64, Int64}}}}}, Zygote.var"#2198#back#309"{Zygote.Jnew{var"#306#308"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, Nothing, false}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/HTsWj/src/compiler/interface2.jl:0
 [19] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#305#307", ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, Tuple{Zygote.ZBack{ChainRules.var"#vect_pullback#1369"{2, Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Zygote.var"#1974#back#190"{Zygote.var"#186#189"{Zygote.Context{false}, GlobalRef, Vector{Float64}}}, Zygote.Pullback{Tuple{typeof(TaylorDiff.derivative), var"#306#308"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, Vector{Float64}, Vector{Float64}, Int64}, Tuple{Zygote.Pullback{Tuple{typeof(TaylorDiff.derivative), var"#306#308"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, Vector{Float64}, Vector{Float64}, Val{2}}, Tuple{Zygote.ZBack{TaylorDiff.var"#extract_derivative_pullback#125"{2, Float64, Int64}}, Zygote.var"#2865#back#684"{Zygote.var"#map_back#678"{TaylorDiff.var"#118#119"{Val{2}}, 2, Tuple{Vector{Float64}, Vector{Float64}}, Tuple{Tuple{Base.OneTo{Int64}}, Tuple{Base.OneTo{Int64}}}, Vector{Tuple{TaylorScalar{Float64, 2}, Zygote.Pullback{Tuple{TaylorDiff.var"#118#119"{Val{2}}, Float64, Float64}, Tuple{Zygote.var"#2168#back#299"{Zygote.var"#back#298"{:vN, Zygote.Context{false}, TaylorDiff.var"#118#119"{Val{2}}, Val{2}}}, Zygote.Pullback{Tuple{typeof(TaylorDiff.make_taylor), Float64, Float64, Val{2}}, Tuple{Zygote.Pullback{Tuple{Type{TaylorScalar{Float64, 2}}, Float64, Float64}, Tuple{Zygote.Pullback{Tuple{Type{TaylorScalar}, Tuple{Float64, Float64}}, Tuple{Zygote.ZBack{TaylorDiff.var"#taylor_scalar_pullback#120"}}}, Zygote.var"#2001#back#200"{typeof(identity)}, Zygote.ZBack{Zygote.var"#Real_pullback#328"}, Zygote.ZBack{Zygote.var"#Real_pullback#328"}}}}}}}}}}}, Zygote.var"#2198#back#309"{Zygote.Jnew{TaylorDiff.var"#118#119"{Val{2}}, Nothing, false}}, 
Zygote.Pullback{Tuple{var"#306#308"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, Vector{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(mlp), Vector{TaylorScalar{Float64, 2}}, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, Tuple{Zygote.ZBack{ComponentArrays.var"#getproperty_adjoint#87"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}, Symbol}}, Zygote.Pullback{Tuple{typeof(*), Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(LinearAlgebra.mul!), Vector{TaylorScalar{Float64, 2}}, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(LinearAlgebra.mul!), Vector{TaylorScalar{Float64, 2}}, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}, Bool, Bool}, Tuple{Zygote.Pullback{Tuple{typeof(LinearAlgebra.generic_matvecmul!), Vector{TaylorScalar{Float64, 2}}, Char, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{TaylorScalar{Float64, 2}}, LinearAlgebra.MulAddMul{true, true, Bool, Bool}}, Any}, Zygote.Pullback{Tuple{Type{LinearAlgebra.MulAddMul}, Bool, Bool}, Any}}}}}, Zygote.ZBack{ChainRules.var"#similar_pullback#914"{Tuple{Vector{TaylorScalar{Float64, 2}}, DataType, Base.OneTo{Int64}}}}, Zygote.Pullback{Tuple{typeof(Base.promote_op), typeof(LinearAlgebra.matprod), Type{Float64}, Type{TaylorScalar{Float64, 2}}}, Tuple{Zygote.Pullback{Tuple{typeof(Core.Compiler.return_type), typeof(LinearAlgebra.matprod), Type{Tuple{Float64, TaylorScalar{Float64, 
2}}}}, Tuple{typeof(Core.Compiler.return_type)}}, Zygote.var"#2001#back#200"{typeof(identity)}, Zygote.var"#2157#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.ZBack{ChainRules.var"#apply_type_pullback#42"{Tuple{DataType, DataType}}}}}}}, Zygote.ZBack{ChainRules.var"#axes_pullback#305"}}}}}, Zygote.var"#2168#back#299"{Zygote.var"#back#298"{:θ, Zygote.Context{false}, var"#306#308"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}}, Zygote.var"#3011#back#778"{Zygote.var"#772#776"{Vector{TaylorScalar{Float64, 2}}}}}}}}, Zygote.var"#1910#back#157"{Zygote.var"#153#156"}, Zygote.ZBack{Zygote.var"#plus_pullback#341"{Tuple{Int64, Int64}}}}}, Zygote.var"#2198#back#309"{Zygote.Jnew{var"#306#308"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}}}, Nothing, false}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/HTsWj/src/compiler/interface.jl:45
 [20] gradient(f::Function, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(W = ViewAxis(1:32, ShapedAxis((16, 2), NamedTuple())),)}}})
    @ Zygote ~/.julia/packages/Zygote/HTsWj/src/compiler/interface.jl:97
 [21] top-level scope
    @ REPL[149]:1
YichengDWu commented 1 year ago

Fixed in https://github.com/JuliaDiff/TaylorDiff.jl/pull/38