lkdvos opened 10 months ago
Codecov: Attention: 5 lines in your changes are missing coverage. Please review.
Comparison is base (517e676) 93.67% compared to head (5ba0421) 91.53%.

| File | Patch % | Lines |
|---|---|---|
| ext/StridedViewsFillArraysExt.jl | 64.28% | 5 missing |
Hi @lkdvos, thank you very much for this PR. As you mentioned, I came across the `StridedView` constructor not being available for FillArrays in an AD situation (see MWE below). I tried using this branch to check whether it resolves the example below, but deep down it throws another error about the missing `pointer` function for FillArrays (though only in one of the two cases below), called within `unsafe_convert` on a `StridedView`. Below are the example and the stack trace when using this branch:
```julia
using TensorOperations, StridedViews, Zygote

A = rand(4, 3, 2)
x = rand(2)

function f(x)
    TensorOperations.@tensor B[a, b] := A[a, b, c] * x[c]
    return sum(B)
end

function g(x)
    TensorOperations.@tensor B[a, b] := A[a, b, c] * x[c]
    return sum(B; dims=2)
end

Zygote.gradient(f, x)  # doesn't work
Zygote.jacobian(g, x)  # works and matches ForwardDiff
```

Following is the error stack trace:
julia> Zygote.jacobian(f,x)
ERROR: conversion to pointer not defined for FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] unsafe_convert(#unused#::Type{Ptr{Float64}}, a::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Base ./pointer.jl:67
[3] pointer
@ ./abstractarray.jl:1245 [inlined]
[4] unsafe_convert(#unused#::Type{Ptr{Float64}}, a::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)})
@ StridedViews ~/.julia/packages/StridedViews/cN0vi/src/stridedview.jl:191
[5] gemm!(transA::Char, transB::Char, alpha::Float64, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, B::StridedView{Float64, 2, Vector{Float64}, typeof(identity)}, beta::Float64, C::StridedView{Float64, 2, Array{Float64, 3}, typeof(identity)})
@ LinearAlgebra.BLAS ~/julia-src/julia-1.9.3/share/julia/stdlib/v1.9/LinearAlgebra/src/blas.jl:1524
[6] _threaded_blas_mul!(C::StridedView{Float64, 2, Array{Float64, 3}, typeof(identity)}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, B::StridedView{Float64, 2, Vector{Float64}, typeof(identity)}, α::VectorInterface.One, β::VectorInterface.Zero, nthreads::Int64)
@ Strided ~/.julia/packages/Strided/l1vm3/src/linalg.jl:105
[7] _mul!
@ ~/.julia/packages/Strided/l1vm3/src/linalg.jl:91 [inlined]
[8] mul!(C::StridedView{Float64, 2, Array{Float64, 3}, typeof(identity)}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, B::StridedView{Float64, 2, Vector{Float64}, typeof(identity)}, α::VectorInterface.One, β::VectorInterface.Zero)
@ Strided ~/.julia/packages/Strided/l1vm3/src/linalg.jl:60
[9] _unsafe_blas_contract!(C::StridedView{Float64, 3, Array{Float64, 3}, typeof(identity)}, ipC::Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, pA::Tuple{Tuple{Int64, Int64}, Tuple{}}, conjA::Symbol, B::StridedView{Float64, 1, Vector{Float64}, typeof(identity)}, pB::Tuple{Tuple{}, Tuple{Int64}}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero)
@ TensorOperations ~/.julia/packages/TensorOperations/LAzcX/src/implementation/strided.jl:163
[10] blas_contract!(C::StridedView{Float64, 3, Array{Float64, 3}, typeof(identity)}, pC::Tuple{Tuple{Int64, Int64, Int64}, Tuple{}}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, pA::Tuple{Tuple{Int64, Int64}, Tuple{}}, conjA::Symbol, B::StridedView{Float64, 1, Vector{Float64}, typeof(identity)}, pB::Tuple{Tuple{}, Tuple{Int64}}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero)
@ TensorOperations ~/.julia/packages/TensorOperations/LAzcX/src/implementation/strided.jl:137
[11] tensorcontract!(C::StridedView{Float64, 3, Array{Float64, 3}, typeof(identity)}, pC::Tuple{Tuple{Int64, Int64, Int64}, Tuple{}}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, pA::Tuple{Tuple{Int64, Int64}, Tuple{}}, conjA::Symbol, B::StridedView{Float64, 1, Vector{Float64}, typeof(identity)}, pB::Tuple{Tuple{}, Tuple{Int64}}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero, backend::TensorOperations.Backend{:StridedBLAS})
@ TensorOperations ~/.julia/packages/TensorOperations/LAzcX/src/implementation/strided.jl:65
[12] tensorcontract!
@ ~/.julia/packages/TensorOperations/LAzcX/src/implementation/abstractarray.jl:63 [inlined]
[13] tensorcontract!
@ ~/.julia/packages/TensorOperations/LAzcX/src/implementation/abstractarray.jl:35 [inlined]
[14] #62
@ ~/.julia/packages/TensorOperations/LAzcX/ext/TensorOperationsChainRulesCoreExt.jl:99 [inlined]
[15] unthunk
@ ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:204 [inlined]
[16] wrap_chainrules_output
@ ~/.julia/packages/Zygote/XJ8pP/src/compiler/chainrules.jl:110 [inlined]
[17] map (repeats 4 times)
@ ./tuple.jl:276 [inlined]
[18] wrap_chainrules_output
@ ~/.julia/packages/Zygote/XJ8pP/src/compiler/chainrules.jl:111 [inlined]
[19] ZBack
@ ~/.julia/packages/Zygote/XJ8pP/src/compiler/chainrules.jl:211 [inlined]
[20] Pullback
@ ./REPL[54]:2 [inlined]
[21] (::Zygote.Pullback{Tuple{typeof(f), Vector{Float64}}, Tuple{Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(scalartype), Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}, Zygote.Pullback{Tuple{typeof(scalartype), Type{Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(TensorOperations.promote_contract), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base.promote_op), typeof(TensorOperations.tensorop), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.ZBack{ChainRules.var"#apply_type_pullback#42"{Tuple{DataType, DataType}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Core.Compiler.return_type), typeof(TensorOperations.tensorop), Type{Tuple{Float64, Float64}}}, Tuple{typeof(Core.Compiler.return_type)}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(scalartype), Array{Float64, 3}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}, Zygote.Pullback{Tuple{typeof(scalartype), Type{Array{Float64, 3}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}}}, Zygote.var"#3027#back#782"{Zygote.var"#776#780"{Matrix{Float64}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensoralloc_contract_pullback#41"{Tuple{DataType, Tuple{Tuple{Int64, Int64}, Tuple{}}, Array{Float64, 3}, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, Symbol, Vector{Float64}, Tuple{Tuple{Int64}, Tuple{}}, Symbol, 
Bool}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#pullback#67"{Matrix{Float64}, Tuple{Tuple{Int64, Int64}, Tuple{}}, Array{Float64, 3}, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, Symbol, Vector{Float64}, Tuple{Tuple{Int64}, Tuple{}}, Symbol, VectorInterface.One, VectorInterface.Zero, Tuple{}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
[22] #291
@ ~/.julia/packages/Zygote/XJ8pP/src/lib/lib.jl:206 [inlined]
[23] (::Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}}, Zygote.Pullback{Tuple{typeof(f), Vector{Float64}}, Tuple{Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(scalartype), Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}, Zygote.Pullback{Tuple{typeof(scalartype), Type{Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(TensorOperations.promote_contract), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base.promote_op), typeof(TensorOperations.tensorop), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.ZBack{ChainRules.var"#apply_type_pullback#42"{Tuple{DataType, DataType}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Core.Compiler.return_type), typeof(TensorOperations.tensorop), Type{Tuple{Float64, Float64}}}, Tuple{typeof(Core.Compiler.return_type)}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(scalartype), Array{Float64, 3}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}, Zygote.Pullback{Tuple{typeof(scalartype), Type{Array{Float64, 3}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}}}, Zygote.var"#3027#back#782"{Zygote.var"#776#780"{Matrix{Float64}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensoralloc_contract_pullback#41"{Tuple{DataType, Tuple{Tuple{Int64, Int64}, Tuple{}}, Array{Float64, 3}, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, 
Symbol, Vector{Float64}, Tuple{Tuple{Int64}, Tuple{}}, Symbol, Bool}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#pullback#67"{Matrix{Float64}, Tuple{Tuple{Int64, Int64}, Tuple{}}, Array{Float64, 3}, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, Symbol, Vector{Float64}, Tuple{Tuple{Int64}, Tuple{}}, Symbol, VectorInterface.One, VectorInterface.Zero, Tuple{}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}}}}})(Δ::Float64)
@ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
[24] Pullback
@ ./operators.jl:1035 [inlined]
[25] (::Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Any})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
[26] Pullback
@ ./operators.jl:1034 [inlined]
[27] Pullback
@ ./operators.jl:1031 [inlined]
[28] (::Zygote.Pullback{Tuple{Base.var"##_#97", Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), ComposedFunction{typeof(Zygote._jvec), typeof(f)}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:outer, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(Zygote._jvec)}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:inner, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(f)}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(f)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(f)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(Zygote._jvec)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(Zygote._jvec)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(Zygote._jvec), typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, typeof(Zygote._jvec)}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Any}, Zygote.var"#2145#back#281"{Zygote.var"#277#280"}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Float64}, Tuple{Zygote.ZBack{ChainRules.var"#vcat_pullback#1416"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}}, Val{1}}}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(vec), Vector{Float64}}, Tuple{}}}}}}}}}})(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
[29] #291
@ ~/.julia/packages/Zygote/XJ8pP/src/lib/lib.jl:206 [inlined]
[30] #2173#back
@ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
[31] Pullback
@ ./operators.jl:1031 [inlined]
[32] (::Zygote.Pullback{Tuple{ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.var"#2370#back#419"{Zygote.var"#pairs_namedtuple_pullback#418"{(), NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{Base.var"##_#97", Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), ComposedFunction{typeof(Zygote._jvec), typeof(f)}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:outer, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(Zygote._jvec)}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:inner, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(f)}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(f)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(f)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(Zygote._jvec)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(Zygote._jvec)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(Zygote._jvec), typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, typeof(Zygote._jvec)}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Any}, Zygote.var"#2145#back#281"{Zygote.var"#277#280"}, 
Zygote.Pullback{Tuple{typeof(Zygote._jvec), Float64}, Tuple{Zygote.ZBack{ChainRules.var"#vcat_pullback#1416"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}}, Val{1}}}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(vec), Vector{Float64}}, Tuple{}}}}}}}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}})(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
[33] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.var"#2370#back#419"{Zygote.var"#pairs_namedtuple_pullback#418"{(), NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{Base.var"##_#97", Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), ComposedFunction{typeof(Zygote._jvec), typeof(f)}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:outer, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(Zygote._jvec)}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:inner, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(f)}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(f)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(f)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(Zygote._jvec)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(Zygote._jvec)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(Zygote._jvec), typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, typeof(Zygote._jvec)}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Any}, 
Zygote.var"#2145#back#281"{Zygote.var"#277#280"}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Float64}, Tuple{Zygote.ZBack{ChainRules.var"#vcat_pullback#1416"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}}, Val{1}}}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(vec), Vector{Float64}}, Tuple{}}}}}}}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}})(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface.jl:45
[34] withjacobian(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/lib/grad.jl:150
[35] jacobian(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/lib/grad.jl:128
[36] top-level scope
@ REPL[63]:1
I thought this would be relevant to this PR, so I'm posting it here; if it requires a new issue instead, let me know. Thank you.
Thanks for reporting this! I think my PR indeed requires a bit more work, as multiplication involving these `Fill`-backed `StridedView`s should not be dispatched to BLAS like that, and your example is exactly the reason I started looking into this. I'll try to make some time to look further into it next week, and I hope to add this as a test case to the TensorOperations suite.
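The point about not sending such products to BLAS can be sketched in a self-contained way. This is only an illustration under assumptions, not the eventual fix in Strided or TensorOperations: `ConstMatrix`, `blas_compatible` and `guarded_mul!` are hypothetical names. `ConstMatrix` stands in for a `Fill`-like array with no memory buffer behind it, and the guard routes anything that is not a dense array with a BLAS element type to a plain loop, so `pointer` is never requested.

```julia
using LinearAlgebra

# Hypothetical stand-in for a FillArrays-style lazy matrix: every entry equals
# `value` and there is no memory buffer behind it, so `pointer(A)` would throw.
struct ConstMatrix{T} <: AbstractMatrix{T}
    value::T
    dims::Tuple{Int,Int}
end
Base.size(A::ConstMatrix) = A.dims
Base.getindex(A::ConstMatrix, i::Int, j::Int) = A.value

# Hypothetical trait: may this array be handed to BLAS via a raw pointer?
blas_compatible(::DenseArray{<:LinearAlgebra.BlasFloat}) = true
blas_compatible(::AbstractArray) = false

# Compute C = A * B, taking the BLAS-backed `mul!` path only when every
# operand is a dense array with a BLAS element type.
function guarded_mul!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat)
    if blas_compatible(C) && blas_compatible(A) && blas_compatible(B)
        return mul!(C, A, B)
    end
    # Generic fallback: only uses getindex/setindex!, never asks for a pointer.
    fill!(C, zero(eltype(C)))
    for j in axes(B, 2), k in axes(A, 2), i in axes(A, 1)
        C[i, j] += A[i, k] * B[k, j]
    end
    return C
end

A = ConstMatrix(2.0, (3, 2))
y = guarded_mul!(zeros(3), A, [1.0, 2.0])  # generic path: [6.0, 6.0, 6.0]
```

The fallback is deliberately naive (no blocking, no size checks); the only point is that the BLAS branch is never taken for a parent without addressable memory.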
Great, thank you very much, that would be very helpful.
This removes the hard restriction that the parent array must subtype `DenseArray`, but only defines a default constructor for `StridedView(::DenseArray)`. (Note that the current implementation does, however, have a constructor `StridedView(::AbstractArray, size, strides, offset, op)`, which is useful so that methods like `conj` etc. do not have to be redefined for non-`DenseArray` `StridedView`s.) Additionally, it defines a package extension (currently implemented solely as an extension, not through Requires for Julia < v1.9) to handle FillArrays, which sometimes show up in, for example, Zygote's automatic differentiation rules.
This fixes #2
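The constructor split described in the PR can be illustrated with a small self-contained toy. This is a hypothetical sketch: `ToyStridedView` and its fields are invented names and do not reproduce the StridedViews source. It only mirrors the idea that the explicit constructor (parent, size, strides, offset, op) accepts any `AbstractArray`, while the convenience constructor that infers size and strides exists only for `DenseArray`. A lazy parent such as a range (or, by assumption, a `Fill`) can therefore still be wrapped, but only when the caller supplies the layout.

```julia
# Hypothetical toy type, NOT StridedViews.StridedView: a parent array plus an
# explicit size, strides, offset and elementwise op (e.g. identity or conj).
struct ToyStridedView{T,N,A<:AbstractArray{T},F} <: AbstractArray{T,N}
    parent::A
    size::NTuple{N,Int}
    strides::NTuple{N,Int}
    offset::Int
    op::F
end

# Full constructor: accepts ANY AbstractArray parent, because the caller
# supplies the memory layout explicitly.
function ToyStridedView(parent::A, size::NTuple{N,Int}, strides::NTuple{N,Int},
                        offset::Int, op::F) where {T,N,A<:AbstractArray{T},F}
    return ToyStridedView{T,N,A,F}(parent, size, strides, offset, op)
end

# Convenience constructor: infers the layout from the parent, so it is only
# defined for DenseArray, whose column-major strides are known.
ToyStridedView(parent::DenseArray) =
    ToyStridedView(parent, size(parent), strides(parent), 0, identity)

Base.size(v::ToyStridedView) = v.size

function Base.getindex(v::ToyStridedView{T,N}, I::Vararg{Int,N}) where {T,N}
    # Linear index into the parent: offset + 1 + sum((I[k] - 1) * strides[k]).
    i = v.offset + 1 + sum((I[k] - 1) * v.strides[k] for k in 1:N; init = 0)
    return v.op(v.parent[i])
end

A = [1 2 3; 4 5 6]
v = ToyStridedView(A)                                 # dense parent: layout inferred
t = ToyStridedView(A, (3, 2), (2, 1), 0, identity)    # transposed view, explicit layout
r = ToyStridedView(1:6, (2, 3), (1, 2), 0, identity)  # non-dense parent, explicit layout
# ToyStridedView(1:6) throws a MethodError: no convenience constructor for ranges.
```

Keeping `op` as a field is what lets an operation like `conj` return a new view of the same parent rather than requiring methods per parent type, which matches the motivation the description gives for the general constructor.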